This commit is contained in:
110
crates/kreuzberg/src/core/batch_mode.rs
Normal file
110
crates/kreuzberg/src/core/batch_mode.rs
Normal file
@@ -0,0 +1,110 @@
|
||||
//! Internal batch mode tracking using tokio task-local storage.
|
||||
//!
|
||||
//! This module provides a way to track whether we're in batch processing mode
|
||||
//! without exposing it in the public API. Extractors check this flag to decide
|
||||
//! whether to use `spawn_blocking` for CPU-intensive work.
|
||||
|
||||
use std::cell::Cell;
|
||||
use tokio::task_local;
|
||||
|
||||
task_local! {
    /// Per-task marker recording whether the current task runs as part of a batch.
    ///
    /// Extractors consult this flag: `true` means CPU-intensive work is handed to
    /// `spawn_blocking` so that files can be processed in parallel, while `false`
    /// (single-file mode) means the work runs directly and the spawn overhead is
    /// avoided.
    static BATCH_MODE: Cell<bool>;
}
|
||||
|
||||
/// Report whether the calling task is currently in batch processing mode.
///
/// Yields `false` when the task-local flag has not been set for this task
/// (single-file mode).
#[cfg(any(
    feature = "pdf",
    feature = "office",
    feature = "excel",
    feature = "excel-wasm",
    feature = "archives"
))]
pub(crate) fn is_batch_mode() -> bool {
    // `try_with` fails when no scope has installed the flag; treat that as "off".
    match BATCH_MODE.try_with(|flag| flag.get()) {
        Ok(enabled) => enabled,
        Err(_) => false,
    }
}
|
||||
|
||||
/// Run a future with batch mode enabled.
|
||||
///
|
||||
/// This sets the task-local BATCH_MODE flag for the duration of the future.
|
||||
pub(crate) async fn with_batch_mode<F, T>(future: F) -> T
|
||||
where
|
||||
F: std::future::Future<Output = T>,
|
||||
{
|
||||
BATCH_MODE.scope(Cell::new(true), future).await
|
||||
}
|
||||
|
||||
#[cfg(all(
    test,
    any(
        feature = "pdf",
        feature = "office",
        feature = "excel",
        feature = "excel-wasm",
        feature = "archives"
    )
))]
mod tests {
    use super::*;

    /// Outside of any `with_batch_mode` scope the flag reads as off.
    #[tokio::test]
    async fn test_batch_mode_not_set_by_default() {
        assert!(!is_batch_mode(), "batch mode should be false by default");
    }

    /// Inside the scope the flag reads as on.
    #[tokio::test]
    async fn test_with_batch_mode_sets_flag() {
        let seen_inside = with_batch_mode(async { is_batch_mode() }).await;

        assert!(seen_inside, "batch mode should be true inside with_batch_mode");
    }

    /// The flag is only visible while the wrapped future runs.
    #[tokio::test]
    async fn test_batch_mode_scoped_to_future() {
        assert!(!is_batch_mode(), "batch mode should be false before");

        let scoped = async {
            assert!(is_batch_mode(), "batch mode should be true inside");
        };
        with_batch_mode(scoped).await;

        assert!(!is_batch_mode(), "batch mode should be false after future completes");
    }

    /// Nesting scopes keeps the flag on at both levels.
    #[tokio::test]
    async fn test_nested_batch_mode_calls() {
        let (outer, inner) = with_batch_mode(async {
            let outer_seen = is_batch_mode();
            let inner_seen = with_batch_mode(async { is_batch_mode() }).await;
            (outer_seen, inner_seen)
        })
        .await;

        assert!(outer, "outer batch mode should be true");
        assert!(inner, "inner batch mode should be true");
    }

    /// Consecutive scopes do not leak the flag into the code between them.
    #[tokio::test]
    async fn test_batch_mode_unaffected_after_with_batch_mode() {
        let first = async {
            assert!(is_batch_mode(), "first call should set batch mode");
        };
        with_batch_mode(first).await;

        assert!(!is_batch_mode(), "batch mode should be false between calls");

        let second = async {
            assert!(is_batch_mode(), "second call should set batch mode");
        };
        with_batch_mode(second).await;

        assert!(!is_batch_mode(), "batch mode should be false after all calls");
    }
}
|
||||
315
crates/kreuzberg/src/core/batch_optimizations.rs
Normal file
315
crates/kreuzberg/src/core/batch_optimizations.rs
Normal file
@@ -0,0 +1,315 @@
|
||||
//! Batch extraction optimizations using object pooling.
|
||||
//!
|
||||
//! This module provides optimized batch processing utilities that leverage
|
||||
//! object pooling to reduce allocations during concurrent extraction of
|
||||
//! multiple documents.
|
||||
//!
|
||||
//! # Performance Impact
|
||||
//!
|
||||
//! - Reuses temporary string/buffer allocations across documents
|
||||
//! - Reduces allocator pressure by ~5-10%
|
||||
//! - Overall throughput improvement of 5-10% for batch operations
|
||||
//!
|
||||
//! # Usage
|
||||
//!
|
||||
//! The batch extraction functions automatically use pooling internally.
|
||||
//! For manual control, use `BatchProcessor` to create pools and manage
|
||||
//! extraction with custom pool sizes.
|
||||
|
||||
use crate::utils::pool::{ByteBufferPool, StringBufferPool, create_byte_buffer_pool, create_string_buffer_pool};
|
||||
use crate::utils::pool_sizing::PoolSizeHint;
|
||||
use parking_lot::Mutex;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
|
||||
/// Settings that control pooled batch processing.
#[cfg_attr(alef, alef(skip))]
#[derive(Debug, Clone)]
pub struct BatchProcessorConfig {
    /// Upper bound on how many string buffers the pool keeps around.
    pub string_pool_size: usize,

    /// Starting capacity, in bytes, of each pooled string buffer.
    pub string_buffer_capacity: usize,

    /// Upper bound on how many byte buffers the pool keeps around.
    pub byte_pool_size: usize,

    /// Starting capacity, in bytes, of each pooled byte buffer.
    pub byte_buffer_capacity: usize,

    /// Cap on simultaneous extractions, used for concurrency control.
    pub max_concurrent: Option<usize>,
}
|
||||
|
||||
impl Default for BatchProcessorConfig {
|
||||
fn default() -> Self {
|
||||
BatchProcessorConfig {
|
||||
string_pool_size: 10,
|
||||
string_buffer_capacity: 8192,
|
||||
byte_pool_size: 10,
|
||||
byte_buffer_capacity: 65536,
|
||||
max_concurrent: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Batch processor that manages object pools for optimized extraction.
|
||||
///
|
||||
/// This struct manages the lifecycle of reusable object pools used during
|
||||
/// batch extraction. Pools are created lazily on first use and reused across
|
||||
/// all documents processed by this batch processor.
|
||||
///
|
||||
/// # Lazy Initialization
|
||||
///
|
||||
/// Pools are initialized on demand to reduce memory usage for applications
|
||||
/// that may not use batch processing immediately or at all.
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub struct BatchProcessor {
|
||||
string_pool: Mutex<Option<Arc<StringBufferPool>>>,
|
||||
byte_pool: Mutex<Option<Arc<ByteBufferPool>>>,
|
||||
config: BatchProcessorConfig,
|
||||
string_pool_initialized: AtomicBool,
|
||||
byte_pool_initialized: AtomicBool,
|
||||
}
|
||||
|
||||
impl BatchProcessor {
|
||||
/// Create a new batch processor with default pool configuration.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new `BatchProcessor` ready to process documents.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use kreuzberg::core::batch_optimizations::BatchProcessor;
|
||||
///
|
||||
/// let processor = BatchProcessor::new();
|
||||
/// ```
|
||||
pub fn new() -> Self {
|
||||
Self::with_config(BatchProcessorConfig::default())
|
||||
}
|
||||
|
||||
/// Create a new batch processor with custom pool configuration.
|
||||
///
|
||||
/// Pools are not created immediately but lazily on first access.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `config` - Custom batch processor configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new `BatchProcessor` configured with the provided settings.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use kreuzberg::core::batch_optimizations::{BatchProcessor, BatchProcessorConfig};
|
||||
///
|
||||
/// let mut config = BatchProcessorConfig::default();
|
||||
/// config.string_pool_size = 20;
|
||||
/// config.string_buffer_capacity = 16384;
|
||||
/// let processor = BatchProcessor::with_config(config);
|
||||
/// ```
|
||||
pub fn with_config(config: BatchProcessorConfig) -> Self {
|
||||
BatchProcessor {
|
||||
string_pool: Mutex::new(None),
|
||||
byte_pool: Mutex::new(None),
|
||||
config,
|
||||
string_pool_initialized: AtomicBool::new(false),
|
||||
byte_pool_initialized: AtomicBool::new(false),
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a batch processor with pool sizes optimized for a specific document.
|
||||
///
|
||||
/// This method uses a `PoolSizeHint` (derived from file size and MIME type)
|
||||
/// to create a batch processor with appropriately sized pools. This reduces
|
||||
/// memory waste by tailoring pool allocation to actual document complexity.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `hint` - Pool sizing hint containing recommended buffer counts and capacities
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A new `BatchProcessor` configured with the hint-based pool sizes
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use kreuzberg::core::batch_optimizations::BatchProcessor;
|
||||
/// use kreuzberg::utils::pool_sizing::estimate_pool_size;
|
||||
///
|
||||
/// let hint = estimate_pool_size(5_000_000, "application/pdf");
|
||||
/// let processor = BatchProcessor::with_pool_hint(&hint);
|
||||
/// ```
|
||||
pub fn with_pool_hint(hint: &PoolSizeHint) -> Self {
|
||||
let config = BatchProcessorConfig {
|
||||
string_pool_size: hint.string_buffer_count,
|
||||
string_buffer_capacity: hint.string_buffer_capacity,
|
||||
byte_pool_size: hint.byte_buffer_count,
|
||||
byte_buffer_capacity: hint.byte_buffer_capacity,
|
||||
max_concurrent: None,
|
||||
};
|
||||
Self::with_config(config)
|
||||
}
|
||||
|
||||
/// Get a reference to the string buffer pool.
|
||||
///
|
||||
/// Creates the pool lazily on first access.
|
||||
/// Useful for custom pooling implementations that need direct pool access.
|
||||
pub fn string_pool(&self) -> Arc<StringBufferPool> {
|
||||
if self.string_pool_initialized.load(Ordering::Acquire) {
|
||||
return Arc::clone(self.string_pool.lock().as_ref().unwrap());
|
||||
}
|
||||
|
||||
let mut pool_opt = self.string_pool.lock();
|
||||
if pool_opt.is_none() {
|
||||
let pool = Arc::new(create_string_buffer_pool(
|
||||
self.config.string_pool_size,
|
||||
self.config.string_buffer_capacity,
|
||||
));
|
||||
*pool_opt = Some(pool);
|
||||
self.string_pool_initialized.store(true, Ordering::Release);
|
||||
}
|
||||
|
||||
Arc::clone(pool_opt.as_ref().unwrap())
|
||||
}
|
||||
|
||||
/// Get a reference to the byte buffer pool.
|
||||
///
|
||||
/// Creates the pool lazily on first access.
|
||||
/// Useful for custom pooling implementations that need direct pool access.
|
||||
pub fn byte_pool(&self) -> Arc<ByteBufferPool> {
|
||||
if self.byte_pool_initialized.load(Ordering::Acquire) {
|
||||
return Arc::clone(self.byte_pool.lock().as_ref().unwrap());
|
||||
}
|
||||
|
||||
let mut pool_opt = self.byte_pool.lock();
|
||||
if pool_opt.is_none() {
|
||||
let pool = Arc::new(create_byte_buffer_pool(
|
||||
self.config.byte_pool_size,
|
||||
self.config.byte_buffer_capacity,
|
||||
));
|
||||
*pool_opt = Some(pool);
|
||||
self.byte_pool_initialized.store(true, Ordering::Release);
|
||||
}
|
||||
|
||||
Arc::clone(pool_opt.as_ref().unwrap())
|
||||
}
|
||||
|
||||
/// Get the current configuration.
|
||||
pub fn config(&self) -> &BatchProcessorConfig {
|
||||
&self.config
|
||||
}
|
||||
|
||||
/// Get the number of pooled string buffers currently available.
|
||||
pub fn string_pool_size(&self) -> usize {
|
||||
self.string_pool.lock().as_ref().map(|p| p.size()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Get the number of pooled byte buffers currently available.
|
||||
pub fn byte_pool_size(&self) -> usize {
|
||||
self.byte_pool.lock().as_ref().map(|p| p.size()).unwrap_or(0)
|
||||
}
|
||||
|
||||
/// Clear all pooled objects, forcing new allocations on next acquire.
|
||||
///
|
||||
/// Useful for memory-constrained environments or to reclaim memory
|
||||
/// after processing large batches.
|
||||
pub fn clear_pools(&self) {
|
||||
if let Some(pool) = self.string_pool.lock().as_ref() {
|
||||
pool.clear();
|
||||
}
|
||||
if let Some(pool) = self.byte_pool.lock().as_ref() {
|
||||
pool.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for BatchProcessor {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh processor reports zero buffers for both pools.
    #[test]
    fn test_batch_processor_creation() {
        let fresh = BatchProcessor::new();

        assert_eq!(fresh.string_pool_size(), 0);
        assert_eq!(fresh.byte_pool_size(), 0);
    }

    /// A custom configuration is stored and exposed through `config()`.
    #[test]
    fn test_batch_processor_with_config() {
        let processor = BatchProcessor::with_config(BatchProcessorConfig {
            string_pool_size: 5,
            string_buffer_capacity: 1024,
            byte_pool_size: 3,
            byte_buffer_capacity: 4096,
            max_concurrent: None,
        });
        let stored = processor.config();

        assert_eq!(stored.string_pool_size, 5);
        assert_eq!(stored.byte_pool_size, 3);
    }

    /// A string buffer acquired after an earlier one was released comes back empty.
    #[test]
    fn test_batch_processor_string_pool_usage() {
        let processor = BatchProcessor::new();
        let strings = processor.string_pool();

        let mut first = strings.acquire();
        first.push_str("test");
        drop(first);

        let second = strings.acquire();
        assert_eq!(second.len(), 0);
    }

    /// A byte buffer acquired after an earlier one was released comes back empty.
    #[test]
    fn test_batch_processor_byte_pool_usage() {
        let processor = BatchProcessor::new();
        let bytes = processor.byte_pool();

        let mut first = bytes.acquire();
        first.extend_from_slice(b"test");
        drop(first);

        let second = bytes.acquire();
        assert_eq!(second.len(), 0);
    }

    /// `clear_pools` empties both pools after buffers were returned to them.
    #[test]
    fn test_batch_processor_clear_pools() {
        let processor = BatchProcessor::new();

        // Acquire one buffer from each pool and release it again.
        drop(processor.string_pool().acquire());
        drop(processor.byte_pool().acquire());

        assert!(processor.string_pool_size() > 0);
        assert!(processor.byte_pool_size() > 0);

        processor.clear_pools();

        assert_eq!(processor.string_pool_size(), 0);
        assert_eq!(processor.byte_pool_size(), 0);
    }
}
|
||||
55
crates/kreuzberg/src/core/config/acceleration.rs
Normal file
55
crates/kreuzberg/src/core/config/acceleration.rs
Normal file
@@ -0,0 +1,55 @@
|
||||
//! Acceleration configuration for ONNX Runtime execution providers.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Hardware acceleration configuration for ONNX Runtime models.
|
||||
///
|
||||
/// Controls which execution provider (CPU, CoreML, CUDA, TensorRT) is used
|
||||
/// for inference in layout detection and embedding generation.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::AccelerationConfig;
|
||||
///
|
||||
/// // Auto-select: CoreML on macOS, CUDA on Linux, CPU elsewhere
|
||||
/// let config = AccelerationConfig::default();
|
||||
///
|
||||
/// // Force CPU only
|
||||
/// let config = AccelerationConfig {
|
||||
/// provider: kreuzberg::ExecutionProviderType::Cpu,
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||
pub struct AccelerationConfig {
|
||||
/// Execution provider to use for ONNX inference.
|
||||
#[serde(default)]
|
||||
pub provider: ExecutionProviderType,
|
||||
|
||||
/// GPU device ID (for CUDA/TensorRT). Ignored for CPU/CoreML/Auto.
|
||||
#[serde(default)]
|
||||
pub device_id: u32,
|
||||
}
|
||||
|
||||
/// ONNX Runtime execution provider type.
|
||||
///
|
||||
/// Determines which hardware backend is used for model inference.
|
||||
/// `Auto` (default) selects the best available provider per platform.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default, PartialEq)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ExecutionProviderType {
|
||||
/// Auto-select: CoreML on macOS, CUDA on Linux, CPU elsewhere.
|
||||
#[default]
|
||||
Auto,
|
||||
/// CPU execution provider (always available).
|
||||
Cpu,
|
||||
/// Apple CoreML (macOS/iOS Neural Engine + GPU).
|
||||
#[serde(alias = "coreml")]
|
||||
CoreMl,
|
||||
/// NVIDIA CUDA GPU acceleration.
|
||||
Cuda,
|
||||
/// NVIDIA TensorRT (optimized CUDA inference).
|
||||
#[serde(alias = "tensorrt")]
|
||||
TensorRt,
|
||||
}
|
||||
140
crates/kreuzberg/src/core/config/concurrency.rs
Normal file
140
crates/kreuzberg/src/core/config/concurrency.rs
Normal file
@@ -0,0 +1,140 @@
|
||||
//! Concurrency and thread pool configuration.
|
||||
|
||||
use std::sync::Once;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Controls thread usage for constrained environments.
|
||||
///
|
||||
/// Set `max_threads` to cap all internal thread pools (Rayon, ONNX Runtime
|
||||
/// intra-op) and batch concurrency to a single limit.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::core::config::ConcurrencyConfig;
|
||||
///
|
||||
/// let config = ConcurrencyConfig {
|
||||
/// max_threads: Some(2),
|
||||
/// };
|
||||
/// ```
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[serde(default)]
|
||||
pub struct ConcurrencyConfig {
|
||||
/// Maximum number of threads for all internal thread pools.
|
||||
///
|
||||
/// Caps Rayon global pool size, ONNX Runtime intra-op threads, and
|
||||
/// (when `max_concurrent_extractions` is unset) the batch concurrency
|
||||
/// semaphore. When `None`, system defaults are used.
|
||||
pub max_threads: Option<usize>,
|
||||
}
|
||||
|
||||
/// One-time guard used by `init_thread_pools` so the global thread pool setup runs at most once.
static POOL_INIT: Once = Once::new();
|
||||
|
||||
/// Resolve the effective thread budget from config or auto-detection.
|
||||
///
|
||||
/// User-set `max_threads` takes priority. Otherwise auto-detects from `num_cpus`,
|
||||
/// capped at 8 for sane defaults in serverless environments.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use kreuzberg::core::config::ConcurrencyConfig;
|
||||
/// use kreuzberg::core::config::concurrency::resolve_thread_budget;
|
||||
///
|
||||
/// let config = ConcurrencyConfig { max_threads: Some(4) };
|
||||
/// assert_eq!(resolve_thread_budget(Some(&config)), 4);
|
||||
/// assert!(resolve_thread_budget(None) >= 1);
|
||||
/// ```
|
||||
pub(crate) fn resolve_thread_budget(config: Option<&ConcurrencyConfig>) -> usize {
|
||||
if let Some(n) = config.and_then(|c| c.max_threads) {
|
||||
return n.max(1);
|
||||
}
|
||||
num_cpus::get().min(8)
|
||||
}
|
||||
|
||||
/// Initialize the global Rayon thread pool with the given budget.
|
||||
///
|
||||
/// Safe to call multiple times — only the first call takes effect (subsequent
|
||||
/// calls are silently ignored).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use kreuzberg::core::config::concurrency::init_thread_pools;
|
||||
///
|
||||
/// init_thread_pools(4);
|
||||
/// init_thread_pools(2); // no-op: pool already initialized
|
||||
/// ```
|
||||
pub(crate) fn init_thread_pools(budget: usize) {
|
||||
POOL_INIT.call_once(|| {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
rayon::ThreadPoolBuilder::new().num_threads(budget).build_global().ok();
|
||||
#[cfg(target_arch = "wasm32")]
|
||||
let _ = budget;
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Without any config the budget falls in the auto-detected 1..=8 range.
    #[test]
    fn test_resolve_thread_budget_none() {
        let detected = resolve_thread_budget(None);
        assert!((1..=8).contains(&detected));
    }

    /// An explicit `max_threads` is returned as-is.
    #[test]
    fn test_resolve_thread_budget_with_config() {
        let four_threads = ConcurrencyConfig { max_threads: Some(4) };
        assert_eq!(resolve_thread_budget(Some(&four_threads)), 4);
    }

    /// A zero request is raised to one thread.
    #[test]
    fn test_resolve_thread_budget_clamps_to_one() {
        let zero_threads = ConcurrencyConfig { max_threads: Some(0) };
        assert_eq!(resolve_thread_budget(Some(&zero_threads)), 1);
    }

    /// A config without `max_threads` behaves like no config at all.
    #[test]
    fn test_resolve_thread_budget_no_max() {
        let unset = ConcurrencyConfig { max_threads: None };
        let detected = resolve_thread_budget(Some(&unset));
        assert!((1..=8).contains(&detected));
    }

    /// Repeated initialization must not panic.
    #[test]
    fn test_init_thread_pools_idempotent() {
        for budget in [2, 4] {
            init_thread_pools(budget);
        }
    }

    /// The default config leaves `max_threads` unset.
    #[test]
    fn test_default() {
        assert!(ConcurrencyConfig::default().max_threads.is_none());
    }

    /// `max_threads` survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_serde_roundtrip() {
        let parsed: ConcurrencyConfig = serde_json::from_str(r#"{"max_threads": 2}"#).unwrap();
        assert_eq!(parsed.max_threads, Some(2));

        let encoded = serde_json::to_string(&parsed).unwrap();
        let reparsed: ConcurrencyConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(reparsed.max_threads, Some(2));
    }

    /// An empty JSON object deserializes to the defaults.
    #[test]
    fn test_serde_empty() {
        let parsed: ConcurrencyConfig = serde_json::from_str(r#"{}"#).unwrap();
        assert!(parsed.max_threads.is_none());
    }
}
|
||||
80
crates/kreuzberg/src/core/config/content_filter.rs
Normal file
80
crates/kreuzberg/src/core/config/content_filter.rs
Normal file
@@ -0,0 +1,80 @@
|
||||
//! Cross-extractor content filtering configuration.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Serde default helper for boolean fields that should start out enabled.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
|
||||
|
||||
/// Cross-extractor content filtering configuration.
|
||||
///
|
||||
/// Controls whether "furniture" content (headers, footers, page numbers,
|
||||
/// watermarks, repeating text) is included in or stripped from extraction
|
||||
/// results. Applies across all extractors (PDF, DOCX, RTF, ODT, HTML, etc.)
|
||||
/// with format-specific implementation.
|
||||
///
|
||||
/// When `None` on `ExtractionConfig`, each extractor uses its current
|
||||
/// default behavior unchanged.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ContentFilterConfig {
|
||||
/// Include running headers in extraction output.
|
||||
///
|
||||
/// - PDF: Disables top-margin furniture stripping and prevents the layout
|
||||
/// model from treating `PageHeader`-classified regions as furniture.
|
||||
/// - DOCX: Includes document headers in text output.
|
||||
/// - RTF/ODT: Headers already included; this is a no-op when true.
|
||||
/// - HTML/EPUB: Keeps `<header>` element content.
|
||||
///
|
||||
/// Default: `false` (headers are stripped or excluded).
|
||||
#[serde(default)]
|
||||
pub include_headers: bool,
|
||||
|
||||
/// Include running footers in extraction output.
|
||||
///
|
||||
/// - PDF: Disables bottom-margin furniture stripping and prevents the layout
|
||||
/// model from treating `PageFooter`-classified regions as furniture.
|
||||
/// - DOCX: Includes document footers in text output.
|
||||
/// - RTF/ODT: Footers already included; this is a no-op when true.
|
||||
/// - HTML/EPUB: Keeps `<footer>` element content.
|
||||
///
|
||||
/// Default: `false` (footers are stripped or excluded).
|
||||
#[serde(default)]
|
||||
pub include_footers: bool,
|
||||
|
||||
/// Enable the heuristic cross-page repeating text detector.
|
||||
///
|
||||
/// When `true` (default), text that repeats verbatim across a supermajority
|
||||
/// of pages is classified as furniture and stripped. Disable this if brand
|
||||
/// names or repeated headings are being incorrectly removed by the heuristic.
|
||||
///
|
||||
/// Note: when a layout-detection model is active, the model may independently
|
||||
/// classify page-header / page-footer regions as furniture on a per-page basis.
|
||||
/// To preserve those regions, set `include_headers = true`, `include_footers = true`,
|
||||
/// or both, in addition to disabling this flag.
|
||||
///
|
||||
/// Primarily affects PDF extraction.
|
||||
///
|
||||
/// Default: `true`.
|
||||
#[serde(default = "default_true")]
|
||||
pub strip_repeating_text: bool,
|
||||
|
||||
/// Include watermark text in extraction output.
|
||||
///
|
||||
/// - PDF: Keeps watermark artifacts and arXiv identifiers.
|
||||
/// - Other formats: No effect currently.
|
||||
///
|
||||
/// Default: `false` (watermarks are stripped).
|
||||
#[serde(default)]
|
||||
pub include_watermarks: bool,
|
||||
}
|
||||
|
||||
impl Default for ContentFilterConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
include_headers: false,
|
||||
include_footers: false,
|
||||
strip_repeating_text: true,
|
||||
include_watermarks: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
58
crates/kreuzberg/src/core/config/email.rs
Normal file
58
crates/kreuzberg/src/core/config/email.rs
Normal file
@@ -0,0 +1,58 @@
|
||||
//! Email extraction configuration.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for email extraction.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
#[derive(Default)]
|
||||
pub struct EmailConfig {
|
||||
/// Windows codepage number to use when an MSG file contains no codepage property.
|
||||
/// Defaults to `None`, which falls back to windows-1252.
|
||||
///
|
||||
/// If an unrecognized or invalid codepage number is supplied (including 0),
|
||||
/// the behavior silently falls back to windows-1252 — the same as when the
|
||||
/// MSG file itself contains an unrecognized codepage. No error or warning is
|
||||
/// emitted. Users should verify output when supplying unusual values.
|
||||
///
|
||||
/// Common values:
|
||||
/// - 1250: Central European (Polish, Czech, Hungarian, etc.)
|
||||
/// - 1251: Cyrillic (Russian, Ukrainian, Bulgarian, etc.)
|
||||
/// - 1252: Western European (default)
|
||||
/// - 1253: Greek
|
||||
/// - 1254: Turkish
|
||||
/// - 1255: Hebrew
|
||||
/// - 1256: Arabic
|
||||
/// - 932: Japanese (Shift-JIS)
|
||||
/// - 936: Simplified Chinese (GBK)
|
||||
pub msg_fallback_codepage: Option<u32>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config has no fallback codepage.
    #[test]
    fn test_email_config_default() {
        assert!(EmailConfig::default().msg_fallback_codepage.is_none());
    }

    /// The fallback codepage survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_email_config_serde_roundtrip() {
        let parsed: EmailConfig = serde_json::from_str(r#"{"msg_fallback_codepage": 1251}"#).unwrap();
        assert_eq!(parsed.msg_fallback_codepage, Some(1251));

        let encoded = serde_json::to_string(&parsed).unwrap();
        let reparsed: EmailConfig = serde_json::from_str(&encoded).unwrap();
        assert_eq!(reparsed.msg_fallback_codepage, Some(1251));
    }

    /// An empty JSON object leaves the fallback codepage unset.
    #[test]
    fn test_email_config_serde_default_omitted() {
        let parsed: EmailConfig = serde_json::from_str(r#"{}"#).unwrap();
        assert!(parsed.msg_fallback_codepage.is_none());
    }
}
|
||||
728
crates/kreuzberg/src/core/config/extraction/core.rs
Normal file
728
crates/kreuzberg/src/core/config/extraction/core.rs
Normal file
@@ -0,0 +1,728 @@
|
||||
//! Main extraction configuration struct.
|
||||
//!
|
||||
//! This module contains the main `ExtractionConfig` struct that aggregates all
|
||||
//! configuration options for the extraction process.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::super::acceleration::AccelerationConfig;
|
||||
use super::super::content_filter::ContentFilterConfig;
|
||||
use super::super::formats::OutputFormat;
|
||||
use super::super::ocr::OcrConfig;
|
||||
use super::super::page::PageConfig;
|
||||
use super::super::processing::{ChunkingConfig, PostProcessorConfig};
|
||||
use super::file_config::FileExtractionConfig;
|
||||
use super::types::{ImageExtractionConfig, LanguageDetectionConfig, TokenReductionOptions};
|
||||
|
||||
/// Main extraction configuration.
|
||||
///
|
||||
/// This struct contains all configuration options for the extraction process.
|
||||
/// It can be loaded from TOML, YAML, or JSON files, or created programmatically.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::core::config::ExtractionConfig;
|
||||
///
|
||||
/// // Create with defaults
|
||||
/// let config = ExtractionConfig::default();
|
||||
///
|
||||
/// // Load from TOML file
|
||||
/// // let config = ExtractionConfig::from_toml_file("kreuzberg.toml")?;
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct ExtractionConfig {
|
||||
/// Enable caching of extraction results
|
||||
#[serde(default = "default_true")]
|
||||
pub use_cache: bool,
|
||||
|
||||
/// Enable quality post-processing
|
||||
#[serde(default = "default_true")]
|
||||
pub enable_quality_processing: bool,
|
||||
|
||||
/// OCR configuration (None = OCR disabled)
|
||||
#[serde(default)]
|
||||
pub ocr: Option<OcrConfig>,
|
||||
|
||||
/// Force OCR even for searchable PDFs
|
||||
#[serde(default)]
|
||||
pub force_ocr: bool,
|
||||
|
||||
/// Force OCR on specific pages only (1-indexed page numbers, must be >= 1).
|
||||
///
|
||||
/// When set, only the listed pages are OCR'd regardless of text layer quality.
|
||||
/// Unlisted pages use native text extraction. Ignored when `force_ocr` is `true`.
|
||||
/// Only applies to PDF documents. Duplicates are automatically deduplicated.
|
||||
/// An `ocr` config is recommended for backend/language selection; defaults are used if absent.
|
||||
#[serde(default)]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub force_ocr_pages: Option<Vec<u32>>,
|
||||
|
||||
/// Disable OCR entirely, even for images.
|
||||
///
|
||||
/// When `true`, OCR is skipped for all document types. Images return metadata
|
||||
/// only (dimensions, format, EXIF) without text extraction. PDFs use only
|
||||
/// native text extraction without OCR fallback.
|
||||
///
|
||||
/// Cannot be `true` simultaneously with `force_ocr`.
|
||||
///
|
||||
/// *Added in v4.7.0.*
|
||||
#[serde(default)]
|
||||
pub disable_ocr: bool,
|
||||
|
||||
/// Text chunking configuration (None = chunking disabled)
|
||||
#[serde(default)]
|
||||
pub chunking: Option<ChunkingConfig>,
|
||||
|
||||
/// Content filtering configuration (None = use extractor defaults).
|
||||
///
|
||||
/// Controls whether document "furniture" (headers, footers, watermarks,
|
||||
/// repeating text) is included in or stripped from extraction results.
|
||||
/// See [`ContentFilterConfig`] for per-field documentation.
|
||||
#[serde(default)]
|
||||
pub content_filter: Option<ContentFilterConfig>,
|
||||
|
||||
/// Image extraction configuration (None = no image extraction)
|
||||
#[serde(default)]
|
||||
pub images: Option<ImageExtractionConfig>,
|
||||
|
||||
/// PDF-specific options (None = use defaults)
|
||||
#[cfg(feature = "pdf")]
|
||||
#[serde(default)]
|
||||
pub pdf_options: Option<super::super::pdf::PdfConfig>,
|
||||
|
||||
/// Token reduction configuration (None = no token reduction)
|
||||
#[serde(default)]
|
||||
pub token_reduction: Option<TokenReductionOptions>,
|
||||
|
||||
/// Language detection configuration (None = no language detection)
|
||||
#[serde(default)]
|
||||
pub language_detection: Option<LanguageDetectionConfig>,
|
||||
|
||||
/// Page extraction configuration (None = no page tracking)
|
||||
#[serde(default)]
|
||||
pub pages: Option<PageConfig>,
|
||||
|
||||
/// Keyword extraction configuration (None = no keyword extraction)
|
||||
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
|
||||
#[serde(default)]
|
||||
pub keywords: Option<crate::keywords::KeywordConfig>,
|
||||
|
||||
/// Post-processor configuration (None = use defaults)
|
||||
#[serde(default)]
|
||||
pub postprocessor: Option<PostProcessorConfig>,
|
||||
|
||||
/// HTML to Markdown conversion options (None = use defaults)
|
||||
///
|
||||
/// Configure how HTML documents are converted to Markdown, including heading styles,
|
||||
/// list formatting, code block styles, and preprocessing options.
|
||||
#[cfg(feature = "html")]
|
||||
#[serde(default)]
|
||||
pub html_options: Option<html_to_markdown_rs::ConversionOptions>,
|
||||
|
||||
/// Styled HTML output configuration.
|
||||
///
|
||||
/// When set alongside `output_format = OutputFormat::Html`, the extraction
|
||||
/// pipeline uses [`StyledHtmlRenderer`](crate::rendering::StyledHtmlRenderer)
|
||||
/// which emits stable `kb-*` CSS class hooks on every structural element
|
||||
/// and optionally embeds theme CSS or user-supplied CSS in a `<style>` block.
|
||||
///
|
||||
/// When `None`, the existing plain comrak-based HTML renderer is used.
|
||||
#[cfg(feature = "html")]
|
||||
#[serde(default)]
|
||||
pub html_output: Option<crate::core::config::html_output::HtmlOutputConfig>,
|
||||
|
||||
/// Default per-file timeout in seconds for batch extraction.
|
||||
///
|
||||
/// When set, each file in a batch will be canceled after this duration
|
||||
/// unless overridden by [`FileExtractionConfig::timeout_secs`].
|
||||
///
|
||||
/// Defaults to `Some(60)` to prevent pathological files (e.g. deeply
|
||||
/// nested archives, documents with millions of cells) from running
|
||||
/// indefinitely and exhausting caller resources. Set to `None` to
|
||||
/// disable the timeout for trusted input or long-running workloads.
|
||||
#[serde(default = "default_extraction_timeout")]
|
||||
pub extraction_timeout_secs: Option<u64>,
|
||||
|
||||
/// Maximum concurrent extractions in batch operations (None = (num_cpus × 1.5).ceil()).
|
||||
///
|
||||
/// Limits parallelism to prevent resource exhaustion when processing
|
||||
/// large batches. Defaults to (num_cpus × 1.5).ceil() when not set.
|
||||
#[serde(default)]
|
||||
pub max_concurrent_extractions: Option<usize>,
|
||||
|
||||
/// Result structure format
|
||||
///
|
||||
/// Controls whether results are returned in unified format (default) with all
|
||||
/// content in the `content` field, or element-based format with semantic
|
||||
/// elements (for Unstructured-compatible output).
|
||||
#[serde(default)]
|
||||
pub result_format: crate::types::ResultFormat,
|
||||
|
||||
/// Security limits for archive extraction.
|
||||
///
|
||||
/// Controls maximum archive size, compression ratio, file count, and other
|
||||
/// security thresholds to prevent decompression bomb attacks. Also caps
|
||||
/// nesting depth, iteration count, entity / token length, total
|
||||
/// content size, and table cell count for every extraction path that
|
||||
/// ingests user-controlled bytes.
|
||||
/// When `None`, default limits are used.
|
||||
#[serde(default)]
|
||||
pub security_limits: Option<crate::extractors::security::SecurityLimits>,
|
||||
|
||||
/// Maximum uncompressed size in bytes for a single embedded file before
|
||||
/// recursive extraction is attempted (default: 50 MiB).
|
||||
///
|
||||
/// Applies to embedded objects inside OOXML containers (DOCX, PPTX) and
|
||||
/// to email attachments processed via recursive extraction. Files that
|
||||
/// exceed this limit are skipped with a `ProcessingWarning` rather than
|
||||
/// passed to the extraction pipeline, preventing a single oversized
|
||||
/// embedded object from consuming unbounded memory or time.
|
||||
///
|
||||
/// Set to `None` to disable the per-embedded-file cap (falls back to
|
||||
/// `security_limits.max_archive_size` as the only guard).
|
||||
#[serde(default = "default_max_embedded_file_bytes")]
|
||||
pub max_embedded_file_bytes: Option<u64>,
|
||||
|
||||
/// Content text format (default: Plain).
|
||||
///
|
||||
/// Controls the format of the extracted content:
|
||||
/// - `Plain`: Raw extracted text (default)
|
||||
/// - `Markdown`: Markdown formatted output
|
||||
/// - `Djot`: Djot markup format (requires djot feature)
|
||||
/// - `Html`: HTML formatted output
|
||||
///
|
||||
/// When set to a structured format, extraction results will include
|
||||
/// formatted output. The `formatted_content` field may be populated
|
||||
/// when format conversion is applied.
|
||||
#[serde(default)]
|
||||
pub output_format: OutputFormat,
|
||||
|
||||
/// Layout detection configuration (None = layout detection disabled).
|
||||
///
|
||||
/// When set, PDF pages and images are analyzed for document structure
|
||||
/// (headings, code, formulas, tables, figures, etc.) using RT-DETR models
|
||||
/// via ONNX Runtime. For PDFs, layout hints override paragraph classification
|
||||
/// in the markdown pipeline. For images, per-region OCR is performed with
|
||||
/// markdown formatting based on detected layout classes.
|
||||
/// Requires the `layout-detection` feature to run inference; the field is
|
||||
/// present whenever the `layout-types` feature is active (which includes
|
||||
/// `layout-detection` as well as the no-ORT target groups).
|
||||
#[cfg(feature = "layout-types")]
|
||||
#[serde(default)]
|
||||
pub layout: Option<super::super::layout::LayoutDetectionConfig>,
|
||||
|
||||
/// Run layout detection on the non-OCR PDF markdown path.
|
||||
///
|
||||
/// When `true` and `layout` is `Some(_)`, layout regions inform heading,
|
||||
/// table, list, and figure detection in the structure pipeline that would
|
||||
/// otherwise rely on font-clustering heuristics alone. Significantly
|
||||
/// improves SF1 (structural F1) at the cost of inference latency
|
||||
/// (~150-300ms/page CPU, ~20-50ms/page GPU). Default: `false`.
|
||||
/// Requires the `layout-detection` feature.
|
||||
#[serde(default)]
|
||||
pub use_layout_for_markdown: bool,
|
||||
|
||||
/// Enable structured document tree output.
|
||||
///
|
||||
/// When true, populates the `document` field on `ExtractionResult` with a
|
||||
/// hierarchical `DocumentStructure` containing heading-driven section nesting,
|
||||
/// table grids, content layer classification, and inline annotations.
|
||||
///
|
||||
/// Independent of `result_format` — can be combined with Unified or ElementBased.
|
||||
#[serde(default)]
|
||||
pub include_document_structure: bool,
|
||||
|
||||
/// Hardware acceleration configuration for ONNX Runtime models.
|
||||
///
|
||||
/// Controls execution provider selection for layout detection and embedding
|
||||
/// models. When `None`, uses platform defaults (CoreML on macOS, CUDA on
|
||||
/// Linux, CPU on Windows).
|
||||
#[serde(default)]
|
||||
pub acceleration: Option<AccelerationConfig>,
|
||||
|
||||
/// Cache namespace for tenant isolation.
|
||||
///
|
||||
/// When set, cache entries are stored under `{cache_dir}/{namespace}/`.
|
||||
/// Must be alphanumeric, hyphens, or underscores only (max 64 chars).
|
||||
/// Different namespaces have isolated cache spaces on the same filesystem.
|
||||
#[serde(default)]
|
||||
pub cache_namespace: Option<String>,
|
||||
|
||||
/// Per-request cache TTL in seconds.
|
||||
///
|
||||
/// Overrides the global `max_age_days` for this specific extraction.
|
||||
/// When `0`, caching is completely skipped (no read or write).
|
||||
/// When `None`, the global TTL applies.
|
||||
#[serde(default)]
|
||||
pub cache_ttl_secs: Option<u64>,
|
||||
|
||||
/// Email extraction configuration (None = use defaults).
|
||||
///
|
||||
/// Currently supports configuring the fallback codepage for MSG files
|
||||
/// that do not specify one. See [`crate::core::config::EmailConfig`] for details.
|
||||
#[serde(default)]
|
||||
pub email: Option<super::super::email::EmailConfig>,
|
||||
|
||||
/// Concurrency limits for constrained environments (None = use defaults).
|
||||
///
|
||||
/// Controls Rayon thread pool size, ONNX Runtime intra-op threads, and
|
||||
/// (when `max_concurrent_extractions` is unset) the batch concurrency
|
||||
/// semaphore. See [`crate::core::config::ConcurrencyConfig`] for details.
|
||||
#[serde(default)]
|
||||
pub concurrency: Option<super::super::concurrency::ConcurrencyConfig>,
|
||||
|
||||
/// Maximum recursion depth for archive extraction (default: 3).
|
||||
/// Set to 0 to disable recursive extraction (legacy behavior).
|
||||
#[serde(default = "default_archive_depth")]
|
||||
pub max_archive_depth: usize,
|
||||
|
||||
/// Tree-sitter language pack configuration (None = tree-sitter disabled).
|
||||
///
|
||||
/// When set, enables code file extraction using tree-sitter parsers.
|
||||
/// Controls grammar download behavior and code analysis options.
|
||||
#[cfg(feature = "tree-sitter")]
|
||||
#[serde(default)]
|
||||
pub tree_sitter: Option<super::super::tree_sitter::TreeSitterConfig>,
|
||||
|
||||
/// Structured extraction via LLM (None = disabled).
|
||||
///
|
||||
/// When set, the extracted document content is sent to an LLM with the
|
||||
/// provided JSON schema. The structured response is stored in
|
||||
/// `ExtractionResult::structured_output`.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub structured_extraction: Option<super::super::llm::StructuredExtractionConfig>,
|
||||
|
||||
/// Cancellation token for this extraction (None = no external cancellation).
|
||||
///
|
||||
/// Pass a [`CancellationToken`] clone here and call [`CancellationToken::cancel`]
|
||||
/// from another thread / task to abort the extraction in progress. The extractor
|
||||
/// checks the token at safe checkpoints (before lock acquisition, between pages,
|
||||
/// between batch items) and returns [`KreuzbergError::Cancelled`] when set.
|
||||
///
|
||||
/// The field is excluded from serialization because `CancellationToken` is a
|
||||
/// runtime handle, not a configuration value.
|
||||
#[serde(skip)]
|
||||
pub cancel_token: Option<crate::cancellation::CancellationToken>,
|
||||
}
|
||||
|
||||
impl Default for ExtractionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
use_cache: true,
|
||||
enable_quality_processing: true,
|
||||
ocr: None,
|
||||
force_ocr: false,
|
||||
force_ocr_pages: None,
|
||||
disable_ocr: false,
|
||||
chunking: None,
|
||||
content_filter: None,
|
||||
images: None,
|
||||
#[cfg(feature = "pdf")]
|
||||
pdf_options: None,
|
||||
token_reduction: None,
|
||||
language_detection: None,
|
||||
pages: None,
|
||||
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
|
||||
keywords: None,
|
||||
postprocessor: None,
|
||||
#[cfg(feature = "html")]
|
||||
html_options: None,
|
||||
#[cfg(feature = "html")]
|
||||
html_output: None,
|
||||
extraction_timeout_secs: default_extraction_timeout(),
|
||||
max_concurrent_extractions: None,
|
||||
security_limits: None,
|
||||
max_embedded_file_bytes: default_max_embedded_file_bytes(),
|
||||
#[cfg(feature = "layout-types")]
|
||||
layout: None,
|
||||
use_layout_for_markdown: false,
|
||||
result_format: crate::types::ResultFormat::Unified,
|
||||
output_format: OutputFormat::Plain,
|
||||
include_document_structure: false,
|
||||
acceleration: None,
|
||||
cache_namespace: None,
|
||||
cache_ttl_secs: None,
|
||||
email: None,
|
||||
concurrency: None,
|
||||
max_archive_depth: default_archive_depth(),
|
||||
#[cfg(feature = "tree-sitter")]
|
||||
tree_sitter: None,
|
||||
structured_extraction: None,
|
||||
cancel_token: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ExtractionConfig {
|
||||
/// Create a new `ExtractionConfig` by applying per-file overrides from a
|
||||
/// [`FileExtractionConfig`]. Fields that are `Some` in the override replace the
|
||||
/// corresponding field in `self`; `None` fields keep the original value.
|
||||
///
|
||||
/// Batch-level fields (`max_concurrent_extractions`, `use_cache`, `acceleration`,
|
||||
/// `security_limits`) are never affected by overrides.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```ignore
|
||||
/// use kreuzberg::{ExtractionConfig, FileExtractionConfig};
|
||||
///
|
||||
/// let base = ExtractionConfig::default();
|
||||
/// let override_config = FileExtractionConfig {
|
||||
/// force_ocr: Some(true),
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// let resolved = base.with_file_overrides(&override_config);
|
||||
/// assert!(resolved.force_ocr);
|
||||
/// ```
|
||||
pub(crate) fn with_file_overrides(&self, overrides: &FileExtractionConfig) -> Self {
|
||||
// Destructure to ensure compile-time exhaustiveness: adding a field to
|
||||
// FileExtractionConfig without handling it here will produce a compile error.
|
||||
let FileExtractionConfig {
|
||||
ref enable_quality_processing,
|
||||
ref ocr,
|
||||
ref force_ocr,
|
||||
ref force_ocr_pages,
|
||||
ref disable_ocr,
|
||||
ref chunking,
|
||||
ref content_filter,
|
||||
ref images,
|
||||
#[cfg(feature = "pdf")]
|
||||
ref pdf_options,
|
||||
ref token_reduction,
|
||||
ref language_detection,
|
||||
ref pages,
|
||||
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
|
||||
ref keywords,
|
||||
ref postprocessor,
|
||||
#[cfg(feature = "html")]
|
||||
ref html_options,
|
||||
ref result_format,
|
||||
ref output_format,
|
||||
ref include_document_structure,
|
||||
#[cfg(feature = "layout-types")]
|
||||
ref layout,
|
||||
ref timeout_secs,
|
||||
#[cfg(feature = "tree-sitter")]
|
||||
ref tree_sitter,
|
||||
ref structured_extraction,
|
||||
} = *overrides;
|
||||
|
||||
let mut config = self.clone();
|
||||
|
||||
if let Some(v) = enable_quality_processing {
|
||||
config.enable_quality_processing = *v;
|
||||
}
|
||||
if let Some(v) = ocr {
|
||||
config.ocr = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = force_ocr {
|
||||
config.force_ocr = *v;
|
||||
}
|
||||
if let Some(v) = force_ocr_pages {
|
||||
config.force_ocr_pages = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = disable_ocr {
|
||||
config.disable_ocr = *v;
|
||||
}
|
||||
if let Some(v) = chunking {
|
||||
config.chunking = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = content_filter {
|
||||
config.content_filter = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = images {
|
||||
config.images = Some(v.clone());
|
||||
}
|
||||
#[cfg(feature = "pdf")]
|
||||
if let Some(v) = pdf_options {
|
||||
config.pdf_options = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = token_reduction {
|
||||
config.token_reduction = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = language_detection {
|
||||
config.language_detection = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = pages {
|
||||
config.pages = Some(v.clone());
|
||||
}
|
||||
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
|
||||
if let Some(v) = keywords {
|
||||
config.keywords = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = postprocessor {
|
||||
config.postprocessor = Some(v.clone());
|
||||
}
|
||||
#[cfg(feature = "html")]
|
||||
if let Some(v) = html_options {
|
||||
config.html_options = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = result_format {
|
||||
config.result_format = *v;
|
||||
}
|
||||
if let Some(v) = output_format {
|
||||
config.output_format = v.clone();
|
||||
}
|
||||
if let Some(v) = include_document_structure {
|
||||
config.include_document_structure = *v;
|
||||
}
|
||||
#[cfg(feature = "layout-types")]
|
||||
if let Some(v) = layout {
|
||||
config.layout = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = timeout_secs {
|
||||
config.extraction_timeout_secs = Some(*v);
|
||||
}
|
||||
#[cfg(feature = "tree-sitter")]
|
||||
if let Some(v) = tree_sitter {
|
||||
config.tree_sitter = Some(v.clone());
|
||||
}
|
||||
if let Some(v) = structured_extraction {
|
||||
config.structured_extraction = Some(v.clone());
|
||||
}
|
||||
|
||||
config
|
||||
}
|
||||
|
||||
/// Normalize configuration for implicit requirements.
|
||||
///
|
||||
/// Currently handles:
|
||||
/// - Auto-enabling `extract_pages` when `result_format` is `ElementBased`, because
|
||||
/// the element transformation requires per-page data to assign correct page numbers.
|
||||
/// Without this, all elements would incorrectly get `page_number=1`.
|
||||
/// - Auto-enabling `extract_pages` when chunking is configured, because the chunker
|
||||
/// needs page boundaries to assign correct page numbers to chunks.
|
||||
pub(crate) fn normalized(&self) -> std::borrow::Cow<'_, Self> {
|
||||
let needs_pages = |cfg: &Self| -> bool {
|
||||
match &cfg.pages {
|
||||
Some(page_config) => !page_config.extract_pages,
|
||||
None => true,
|
||||
}
|
||||
};
|
||||
|
||||
let needs_pages_for_elements =
|
||||
self.result_format == crate::types::ResultFormat::ElementBased && needs_pages(self);
|
||||
let needs_pages_for_chunking = self.chunking.is_some() && needs_pages(self);
|
||||
|
||||
if needs_pages_for_elements || needs_pages_for_chunking {
|
||||
let mut config = self.clone();
|
||||
let page_config = config.pages.get_or_insert_with(super::super::page::PageConfig::default);
|
||||
page_config.extract_pages = true;
|
||||
return std::borrow::Cow::Owned(config);
|
||||
}
|
||||
std::borrow::Cow::Borrowed(self)
|
||||
}
|
||||
|
||||
/// Validate the configuration, returning an error if any settings are invalid.
|
||||
///
|
||||
/// Checks:
|
||||
/// Returns the effective disable-OCR value, accounting for both the top-level
|
||||
/// `disable_ocr` flag and the `ocr.enabled` shorthand on [`OcrConfig`].
|
||||
///
|
||||
/// Setting `ocr.enabled = false` in configuration is treated as equivalent to
|
||||
/// `disable_ocr = true`. This method is the single source of truth for whether
|
||||
/// OCR should be skipped.
|
||||
pub(crate) fn effective_disable_ocr(&self) -> bool {
|
||||
self.disable_ocr || self.ocr.as_ref().is_some_and(|o| !o.enabled)
|
||||
}
|
||||
|
||||
/// Check if image processing is needed by examining OCR and image extraction settings.
|
||||
///
|
||||
/// Returns `true` if either OCR is enabled or image extraction is configured,
|
||||
/// indicating that image decompression and processing should occur.
|
||||
/// Returns `false` if both are disabled, allowing optimization to skip unnecessary
|
||||
/// image decompression for text-only extraction workflows.
|
||||
///
|
||||
/// # Optimization Impact
|
||||
/// For text-only extractions (no OCR, no image extraction), skipping image
|
||||
/// decompression can improve CPU utilization by 5-10% by avoiding wasteful
|
||||
/// image I/O and processing when results won't be used.
|
||||
pub fn needs_image_processing(&self) -> bool {
|
||||
let ocr_enabled = !self.effective_disable_ocr() && (self.ocr.is_some() || self.force_ocr);
|
||||
|
||||
let image_extraction_enabled = self.images.as_ref().map(|i| i.extract_images).unwrap_or(false);
|
||||
|
||||
#[cfg(feature = "layout-detection")]
|
||||
let layout_enabled = self.layout.is_some();
|
||||
#[cfg(not(feature = "layout-detection"))]
|
||||
let layout_enabled = false;
|
||||
|
||||
ocr_enabled || image_extraction_enabled || layout_enabled
|
||||
}
|
||||
}
|
||||
|
||||
/// Always yields `true`.
///
/// NOTE(review): presumably referenced by name from a `#[serde(default = "...")]`
/// attribute on a field outside this view — confirm before removing.
fn default_true() -> bool {
    const ENABLED: bool = true;
    ENABLED
}
|
||||
|
||||
/// Default maximum recursion depth for archive extraction: three levels.
fn default_archive_depth() -> usize {
    const DEFAULT_MAX_ARCHIVE_DEPTH: usize = 3;
    DEFAULT_MAX_ARCHIVE_DEPTH
}
|
||||
|
||||
/// Default size cap for a single embedded file: 50 MiB.
///
/// The recursive extractor materialises each embedded object, so one that is
/// larger than this can consume significant memory. 50 MiB leaves plenty of
/// room for real-world embedded documents while still bounding the
/// worst-case allocation.
fn default_max_embedded_file_bytes() -> Option<u64> {
    const MIB: u64 = 1024 * 1024;
    Some(50 * MIB)
}
|
||||
|
||||
/// Default extraction timeout: 60 seconds.
///
/// Without a bound, pathological inputs (deeply nested archives, sheets with
/// millions of cells, adversarial PDFs) can run indefinitely and exhaust
/// caller resources. 60 s is generous for legitimate documents while capping
/// the worst-case cost of a single untrusted input.
fn default_extraction_timeout() -> Option<u64> {
    const DEFAULT_TIMEOUT_SECS: u64 = 60;
    Some(DEFAULT_TIMEOUT_SECS)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::OcrConfig;

    /// Default config whose `ocr` section has `enabled` set to `flag`.
    fn with_ocr_enabled(flag: bool) -> ExtractionConfig {
        let ocr = OcrConfig {
            enabled: flag,
            ..Default::default()
        };
        ExtractionConfig {
            ocr: Some(ocr),
            ..Default::default()
        }
    }

    /// Deserialize a config from a JSON literal, panicking on malformed input.
    fn parse(payload: &str) -> ExtractionConfig {
        serde_json::from_str(payload).unwrap()
    }

    /// Serialize `cfg` to JSON and read it straight back.
    fn round_trip(cfg: &ExtractionConfig) -> ExtractionConfig {
        let encoded = serde_json::to_string(cfg).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    // --- effective_disable_ocr ---------------------------------------------

    #[test]
    fn test_effective_disable_ocr_from_top_level_flag() {
        let cfg = ExtractionConfig {
            disable_ocr: true,
            ..Default::default()
        };
        assert!(cfg.effective_disable_ocr());
    }

    #[test]
    fn test_effective_disable_ocr_from_ocr_enabled_false() {
        let cfg = with_ocr_enabled(false);
        assert!(
            cfg.effective_disable_ocr(),
            "ocr.enabled = false should be treated as disable_ocr = true"
        );
    }

    #[test]
    fn test_effective_disable_ocr_default_is_false() {
        assert!(!ExtractionConfig::default().effective_disable_ocr());
    }

    #[test]
    fn test_effective_disable_ocr_ocr_enabled_true_does_not_disable() {
        let cfg = with_ocr_enabled(true);
        assert!(!cfg.effective_disable_ocr());
    }

    #[test]
    fn test_ocr_enabled_false_deserialized_from_json() {
        let cfg = parse(r#"{"ocr": {"enabled": false}}"#);
        assert!(
            cfg.effective_disable_ocr(),
            "JSON ocr.enabled=false should disable OCR"
        );
    }

    #[test]
    fn test_ocr_enabled_defaults_to_true() {
        let cfg = parse(r#"{"ocr": {"backend": "tesseract"}}"#);
        assert!(!cfg.effective_disable_ocr(), "OCR should be enabled by default");
    }

    // --- use_layout_for_markdown -------------------------------------------

    #[cfg(feature = "layout-detection")]
    #[test]
    fn test_use_layout_for_markdown_defaults_to_false() {
        assert!(!ExtractionConfig::default().use_layout_for_markdown);
    }

    #[cfg(feature = "layout-detection")]
    #[test]
    fn test_use_layout_for_markdown_can_be_set_true() {
        let cfg = ExtractionConfig {
            use_layout_for_markdown: true,
            ..Default::default()
        };
        assert!(cfg.use_layout_for_markdown);
    }

    #[cfg(feature = "layout-detection")]
    #[test]
    fn test_use_layout_for_markdown_serde_round_trip() {
        let original = ExtractionConfig {
            use_layout_for_markdown: true,
            ..Default::default()
        };
        let restored = round_trip(&original);
        assert!(restored.use_layout_for_markdown);
    }

    #[cfg(feature = "layout-detection")]
    #[test]
    fn test_use_layout_for_markdown_serde_default_false() {
        // An empty JSON object omits the field, so the serde default applies.
        let cfg = parse(r#"{}"#);
        assert!(!cfg.use_layout_for_markdown);
    }

    // --- extraction_timeout_secs defaults ----------------------------------

    #[test]
    fn test_default_extraction_timeout_is_sixty_seconds() {
        let cfg = ExtractionConfig::default();
        assert_eq!(
            cfg.extraction_timeout_secs,
            Some(60),
            "default timeout must be Some(60) to prevent unbounded extraction"
        );
    }

    #[test]
    fn test_extraction_timeout_can_be_disabled_by_setting_none() {
        let cfg = ExtractionConfig {
            extraction_timeout_secs: None,
            ..Default::default()
        };
        assert_eq!(cfg.extraction_timeout_secs, None);
    }

    #[test]
    fn test_extraction_timeout_serde_round_trip() {
        let original = ExtractionConfig {
            extraction_timeout_secs: Some(120),
            ..Default::default()
        };
        let restored = round_trip(&original);
        assert_eq!(restored.extraction_timeout_secs, Some(120));
    }

    #[test]
    fn test_extraction_timeout_serde_absent_field_defaults_to_sixty() {
        // With the key missing from the JSON, serde must call the default function.
        let cfg = parse(r#"{}"#);
        assert_eq!(
            cfg.extraction_timeout_secs,
            Some(60),
            "absent field must use default_extraction_timeout() -> Some(60)"
        );
    }
}
|
||||
493
crates/kreuzberg/src/core/config/extraction/env.rs
Normal file
493
crates/kreuzberg/src/core/config/extraction/env.rs
Normal file
@@ -0,0 +1,493 @@
|
||||
//! Environment variable override support for extraction configuration.
|
||||
//!
|
||||
//! This module provides functionality to apply environment variable overrides
|
||||
//! to extraction configuration, allowing runtime configuration changes.
|
||||
|
||||
use crate::{KreuzbergError, Result};
|
||||
|
||||
use super::super::ocr::OcrConfig;
|
||||
use super::super::processing::ChunkingConfig;
|
||||
use super::core::ExtractionConfig;
|
||||
use super::types::TokenReductionOptions;
|
||||
|
||||
impl ExtractionConfig {
|
||||
/// Apply environment variable overrides to configuration.
|
||||
///
|
||||
/// Environment variables have the highest precedence and will override any values
|
||||
/// loaded from configuration files. This method supports the following environment variables:
|
||||
///
|
||||
/// - `KREUZBERG_OCR_LANGUAGE`: OCR language (ISO 639-1 or 639-3 code, e.g., "eng", "fra", "deu")
|
||||
/// - `KREUZBERG_OCR_BACKEND`: OCR backend ("tesseract", "easyocr", or "paddleocr")
|
||||
/// - `KREUZBERG_CHUNKING_MAX_CHARS`: Maximum characters per chunk (positive integer)
|
||||
/// - `KREUZBERG_CHUNKING_MAX_OVERLAP`: Maximum overlap between chunks (non-negative integer)
|
||||
/// - `KREUZBERG_CACHE_ENABLED`: Cache enabled flag ("true" or "false")
|
||||
/// - `KREUZBERG_TOKEN_REDUCTION_MODE`: Token reduction mode ("off", "light", "moderate", "aggressive", or "maximum")
|
||||
/// - `KREUZBERG_CHUNKING_TOKENIZER`: HuggingFace tokenizer model ID for token-based chunk sizing (requires `chunking-tokenizers` feature)
|
||||
/// - `KREUZBERG_DISABLE_OCR`: Disable OCR entirely ("true" or "false")
|
||||
/// - `KREUZBERG_LLM_MODEL`: LLM model for structured extraction (e.g., "openai/gpt-4o")
|
||||
/// - `KREUZBERG_LLM_API_KEY`: API key for the structured extraction LLM provider
|
||||
/// - `KREUZBERG_LLM_BASE_URL`: Custom base URL for the structured extraction LLM provider
|
||||
/// - `KREUZBERG_VLM_OCR_MODEL`: VLM model for vision-based OCR (e.g., "openai/gpt-4o")
|
||||
/// - `KREUZBERG_VLM_EMBEDDING_MODEL`: LLM model for embedding generation (e.g., "openai/text-embedding-3-small")
|
||||
/// - `KREUZBERG_EMBEDDING_PLUGIN_NAME`: Name of an in-process embedding backend registered via `plugins::register_embedding_backend`
|
||||
/// - `KREUZBERG_MSG_FALLBACK_CODEPAGE`: (deferred) Windows codepage for MSG PT_STRING8 fallback
|
||||
///
|
||||
/// # Behavior
|
||||
///
|
||||
/// - If an environment variable is set and valid, it overrides the current configuration value
|
||||
/// - If a required parent config is `None` (e.g., `self.ocr` is None), it's created with defaults before applying the override
|
||||
/// - Invalid values return a `KreuzbergError::Validation` with helpful error messages
|
||||
/// - Missing or unset environment variables are silently ignored
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// # use kreuzberg::core::config::ExtractionConfig;
|
||||
/// # fn example() -> kreuzberg::Result<()> {
|
||||
/// let mut config = ExtractionConfig::from_file("config.toml")?;
|
||||
/// // Set KREUZBERG_OCR_LANGUAGE=fra before calling
|
||||
/// config.apply_env_overrides()?; // OCR language is now "fra"
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if:
|
||||
/// - An environment variable contains an invalid value
|
||||
/// - A number cannot be parsed as the expected type
|
||||
/// - A boolean is not "true" or "false"
|
||||
pub fn apply_env_overrides(&mut self) -> Result<()> {
|
||||
use crate::core::config_validation::{
|
||||
validate_chunking_params, validate_language_code, validate_ocr_backend, validate_token_reduction_level,
|
||||
};
|
||||
|
||||
// KREUZBERG_OCR_LANGUAGE override
|
||||
if let Ok(lang) = std::env::var("KREUZBERG_OCR_LANGUAGE") {
|
||||
validate_language_code(&lang)?;
|
||||
if self.ocr.is_none() {
|
||||
self.ocr = Some(OcrConfig::default());
|
||||
}
|
||||
if let Some(ref mut ocr) = self.ocr {
|
||||
ocr.language = lang;
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_OCR_BACKEND override
|
||||
if let Ok(backend) = std::env::var("KREUZBERG_OCR_BACKEND") {
|
||||
validate_ocr_backend(&backend)?;
|
||||
if self.ocr.is_none() {
|
||||
self.ocr = Some(OcrConfig::default());
|
||||
}
|
||||
if let Some(ref mut ocr) = self.ocr {
|
||||
ocr.backend = backend;
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_CHUNKING_MAX_CHARS override
|
||||
if let Ok(max_chars_str) = std::env::var("KREUZBERG_CHUNKING_MAX_CHARS") {
|
||||
let max_chars: usize = max_chars_str.parse().map_err(|_| KreuzbergError::Validation {
|
||||
message: format!(
|
||||
"Invalid value for KREUZBERG_CHUNKING_MAX_CHARS: '{}'. Must be a positive integer.",
|
||||
max_chars_str
|
||||
),
|
||||
source: None,
|
||||
})?;
|
||||
|
||||
if max_chars == 0 {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "KREUZBERG_CHUNKING_MAX_CHARS must be greater than 0".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
|
||||
if self.chunking.is_none() {
|
||||
self.chunking = Some(ChunkingConfig::default());
|
||||
}
|
||||
|
||||
if let Some(ref mut chunking) = self.chunking {
|
||||
// Validate against current overlap before updating
|
||||
validate_chunking_params(max_chars, chunking.overlap)?;
|
||||
chunking.max_characters = max_chars;
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_CHUNKING_MAX_OVERLAP override
|
||||
if let Ok(max_overlap_str) = std::env::var("KREUZBERG_CHUNKING_MAX_OVERLAP") {
|
||||
let max_overlap: usize = max_overlap_str.parse().map_err(|_| KreuzbergError::Validation {
|
||||
message: format!(
|
||||
"Invalid value for KREUZBERG_CHUNKING_MAX_OVERLAP: '{}'. Must be a non-negative integer.",
|
||||
max_overlap_str
|
||||
),
|
||||
source: None,
|
||||
})?;
|
||||
|
||||
if self.chunking.is_none() {
|
||||
self.chunking = Some(ChunkingConfig::default());
|
||||
}
|
||||
|
||||
if let Some(ref mut chunking) = self.chunking {
|
||||
// Validate against current max_characters before updating
|
||||
validate_chunking_params(chunking.max_characters, max_overlap)?;
|
||||
chunking.overlap = max_overlap;
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_CACHE_ENABLED override
|
||||
if let Ok(cache_str) = std::env::var("KREUZBERG_CACHE_ENABLED") {
|
||||
let cache_enabled = match cache_str.to_lowercase().as_str() {
|
||||
"true" => true,
|
||||
"false" => false,
|
||||
_ => {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: format!(
|
||||
"Invalid value for KREUZBERG_CACHE_ENABLED: '{}'. Must be 'true' or 'false'.",
|
||||
cache_str
|
||||
),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
};
|
||||
self.use_cache = cache_enabled;
|
||||
}
|
||||
|
||||
// KREUZBERG_TOKEN_REDUCTION_MODE override
|
||||
if let Ok(mode) = std::env::var("KREUZBERG_TOKEN_REDUCTION_MODE") {
|
||||
validate_token_reduction_level(&mode)?;
|
||||
if self.token_reduction.is_none() {
|
||||
self.token_reduction = Some(TokenReductionOptions {
|
||||
mode: "off".to_string(),
|
||||
preserve_important_words: true,
|
||||
});
|
||||
}
|
||||
if let Some(ref mut token_reduction) = self.token_reduction {
|
||||
token_reduction.mode = mode;
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_OUTPUT_FORMAT override
|
||||
if let Ok(val) = std::env::var("KREUZBERG_OUTPUT_FORMAT") {
|
||||
self.output_format = val.parse().map_err(|e: String| KreuzbergError::Validation {
|
||||
message: format!("Invalid value for KREUZBERG_OUTPUT_FORMAT: {}", e),
|
||||
source: None,
|
||||
})?;
|
||||
}
|
||||
|
||||
// KREUZBERG_CHUNKING_TOKENIZER override
|
||||
#[cfg(feature = "chunking-tokenizers")]
|
||||
if let Ok(model) = std::env::var("KREUZBERG_CHUNKING_TOKENIZER") {
|
||||
if model.is_empty() {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "KREUZBERG_CHUNKING_TOKENIZER must not be empty".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
|
||||
if self.chunking.is_none() {
|
||||
self.chunking = Some(ChunkingConfig::default());
|
||||
}
|
||||
|
||||
if let Some(ref mut chunking) = self.chunking {
|
||||
chunking.sizing = crate::core::config::processing::ChunkSizing::Tokenizer { model, cache_dir: None };
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_LAYOUT_PRESET override (backward compat: enables layout detection).
|
||||
// Only one model (RT-DETR) exists, so the specific preset value is ignored.
|
||||
#[cfg(feature = "layout-detection")]
|
||||
if let Ok(preset) = std::env::var("KREUZBERG_LAYOUT_PRESET") {
|
||||
let lower = preset.to_lowercase();
|
||||
if !["fast", "accurate", "yolo", "rtdetr", "rt-detr"].contains(&lower.as_str()) {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: format!(
|
||||
"Invalid value for KREUZBERG_LAYOUT_PRESET: '{}'. Valid presets: fast, accurate",
|
||||
preset
|
||||
),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
if self.layout.is_none() {
|
||||
self.layout = Some(super::super::layout::LayoutDetectionConfig::default());
|
||||
}
|
||||
// preset value is accepted but ignored -- only RT-DETR is available
|
||||
let _ = lower;
|
||||
}
|
||||
|
||||
// KREUZBERG_DISABLE_OCR override
|
||||
if let Ok(val) = std::env::var("KREUZBERG_DISABLE_OCR") {
|
||||
self.disable_ocr = match val.to_lowercase().as_str() {
|
||||
"true" | "1" => true,
|
||||
"false" | "0" => false,
|
||||
_ => {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: format!(
|
||||
"Invalid value for KREUZBERG_DISABLE_OCR: '{}'. Must be 'true' or 'false'.",
|
||||
val
|
||||
),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
// KREUZBERG_LLM_MODEL override
|
||||
if let Ok(value) = std::env::var("KREUZBERG_LLM_MODEL") {
|
||||
if value.is_empty() {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "KREUZBERG_LLM_MODEL must not be empty".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
if self.structured_extraction.is_none() {
|
||||
self.structured_extraction = Some(super::super::llm::StructuredExtractionConfig {
|
||||
schema: serde_json::Value::Object(Default::default()),
|
||||
schema_name: "extraction".to_string(),
|
||||
schema_description: None,
|
||||
strict: false,
|
||||
prompt: None,
|
||||
llm: super::super::llm::LlmConfig {
|
||||
model: value,
|
||||
api_key: None,
|
||||
base_url: None,
|
||||
timeout_secs: None,
|
||||
max_retries: None,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
},
|
||||
});
|
||||
} else if let Some(ref mut config) = self.structured_extraction {
|
||||
config.llm.model = value;
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_LLM_API_KEY override
|
||||
if let Ok(value) = std::env::var("KREUZBERG_LLM_API_KEY") {
|
||||
if value.is_empty() {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "KREUZBERG_LLM_API_KEY must not be empty".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
if self.structured_extraction.is_none() {
|
||||
self.structured_extraction = Some(super::super::llm::StructuredExtractionConfig {
|
||||
schema: serde_json::Value::Object(Default::default()),
|
||||
schema_name: "extraction".to_string(),
|
||||
schema_description: None,
|
||||
strict: false,
|
||||
prompt: None,
|
||||
llm: super::super::llm::LlmConfig {
|
||||
model: String::new(),
|
||||
api_key: Some(value),
|
||||
base_url: None,
|
||||
timeout_secs: None,
|
||||
max_retries: None,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
},
|
||||
});
|
||||
} else if let Some(ref mut config) = self.structured_extraction {
|
||||
config.llm.api_key = Some(value);
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_LLM_BASE_URL override
|
||||
if let Ok(value) = std::env::var("KREUZBERG_LLM_BASE_URL") {
|
||||
if value.is_empty() {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "KREUZBERG_LLM_BASE_URL must not be empty".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
if self.structured_extraction.is_none() {
|
||||
self.structured_extraction = Some(super::super::llm::StructuredExtractionConfig {
|
||||
schema: serde_json::Value::Object(Default::default()),
|
||||
schema_name: "extraction".to_string(),
|
||||
schema_description: None,
|
||||
strict: false,
|
||||
prompt: None,
|
||||
llm: super::super::llm::LlmConfig {
|
||||
model: String::new(),
|
||||
api_key: None,
|
||||
base_url: Some(value),
|
||||
timeout_secs: None,
|
||||
max_retries: None,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
},
|
||||
});
|
||||
} else if let Some(ref mut config) = self.structured_extraction {
|
||||
config.llm.base_url = Some(value);
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_VLM_OCR_MODEL override
|
||||
if let Ok(value) = std::env::var("KREUZBERG_VLM_OCR_MODEL") {
|
||||
if value.is_empty() {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "KREUZBERG_VLM_OCR_MODEL must not be empty".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
if self.ocr.is_none() {
|
||||
self.ocr = Some(OcrConfig::default());
|
||||
}
|
||||
if let Some(ref mut ocr) = self.ocr {
|
||||
if ocr.vlm_config.is_none() {
|
||||
ocr.vlm_config = Some(super::super::llm::LlmConfig {
|
||||
model: value,
|
||||
api_key: None,
|
||||
base_url: None,
|
||||
timeout_secs: None,
|
||||
max_retries: None,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
});
|
||||
} else if let Some(ref mut vlm) = ocr.vlm_config {
|
||||
vlm.model = value;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_VLM_EMBEDDING_MODEL override
|
||||
if let Ok(value) = std::env::var("KREUZBERG_VLM_EMBEDDING_MODEL") {
|
||||
if value.is_empty() {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "KREUZBERG_VLM_EMBEDDING_MODEL must not be empty".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
if self.chunking.is_none() {
|
||||
self.chunking = Some(ChunkingConfig::default());
|
||||
}
|
||||
if let Some(ref mut chunking) = self.chunking {
|
||||
chunking.embedding = Some(super::super::processing::EmbeddingConfig {
|
||||
model: super::super::processing::EmbeddingModelType::Llm {
|
||||
llm: super::super::llm::LlmConfig {
|
||||
model: value,
|
||||
api_key: None,
|
||||
base_url: None,
|
||||
timeout_secs: None,
|
||||
max_retries: None,
|
||||
temperature: None,
|
||||
max_tokens: None,
|
||||
},
|
||||
},
|
||||
..super::super::processing::EmbeddingConfig::default()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
// KREUZBERG_EMBEDDING_PLUGIN_NAME override.
|
||||
// Selects an already-registered in-process embedding backend by name.
|
||||
// Setting this together with KREUZBERG_VLM_EMBEDDING_MODEL is rejected — they
|
||||
// configure mutually-exclusive embedding sources and the result of "both set"
|
||||
// would otherwise depend on source order in this function. Pick one.
|
||||
let plugin_name = std::env::var("KREUZBERG_EMBEDDING_PLUGIN_NAME").ok();
|
||||
if plugin_name.is_some() && std::env::var("KREUZBERG_VLM_EMBEDDING_MODEL").is_ok() {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message:
|
||||
"KREUZBERG_EMBEDDING_PLUGIN_NAME and KREUZBERG_VLM_EMBEDDING_MODEL are mutually exclusive — set one or the other, not both."
|
||||
.to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
if let Some(value) = plugin_name {
|
||||
if value.is_empty() {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "KREUZBERG_EMBEDDING_PLUGIN_NAME must not be empty".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
if self.chunking.is_none() {
|
||||
self.chunking = Some(ChunkingConfig::default());
|
||||
}
|
||||
if let Some(ref mut chunking) = self.chunking {
|
||||
chunking.embedding = Some(super::super::processing::EmbeddingConfig {
|
||||
model: super::super::processing::EmbeddingModelType::Plugin { name: value },
|
||||
..super::super::processing::EmbeddingConfig::default()
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
#[allow(unsafe_code)] // env mutation in 2024 edition is unsafe; tests serialize via ENV_LOCK
mod tests {
    use super::*;
    use crate::core::config::processing::EmbeddingModelType;

    /// Lock guarding env-var mutation across tests in this module — `std::env::set_var`
    /// is process-global and concurrent tests would race.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Removes the two embedding-related env vars these tests manipulate.
    fn clear_embedding_env() {
        // SAFETY: callers hold ENV_LOCK, which serializes every env mutation made by
        // the tests in this module.
        unsafe {
            std::env::remove_var("KREUZBERG_EMBEDDING_PLUGIN_NAME");
            std::env::remove_var("KREUZBERG_VLM_EMBEDDING_MODEL");
        }
    }

    /// RAII scope for a test that mutates the embedding env vars.
    ///
    /// Acquiring the scope takes `ENV_LOCK` and starts from a clean environment.
    /// Dropping it clears the variables again, so cleanup also happens when an
    /// assertion panics and unwinds out of the test body.
    struct EmbeddingEnvScope {
        _lock: std::sync::MutexGuard<'static, ()>,
    }

    impl EmbeddingEnvScope {
        /// Takes the lock (recovering from poisoning left by a failed test) and
        /// clears any leftover embedding env vars.
        fn acquire() -> Self {
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            clear_embedding_env();
            Self { _lock: lock }
        }
    }

    impl Drop for EmbeddingEnvScope {
        fn drop(&mut self) {
            // `drop` runs before the `_lock` field is released, so the lock is
            // still held while the variables are removed.
            clear_embedding_env();
        }
    }

    #[test]
    fn embedding_plugin_and_vlm_embedding_model_are_mutually_exclusive() {
        let _env = EmbeddingEnvScope::acquire();
        // SAFETY: `_env` holds ENV_LOCK for the remainder of this test.
        unsafe {
            std::env::set_var("KREUZBERG_EMBEDDING_PLUGIN_NAME", "my-embedder");
            std::env::set_var("KREUZBERG_VLM_EMBEDDING_MODEL", "openai/text-embedding-3-small");
        }

        let mut config = ExtractionConfig::default();
        let err = config
            .apply_env_overrides()
            .expect_err("should reject conflicting embedding env vars");

        assert!(
            matches!(err, KreuzbergError::Validation { .. }),
            "expected Validation, got {err:?}"
        );
        let msg = err.to_string();
        assert!(msg.contains("mutually exclusive"), "message: {msg}");
    }

    #[test]
    fn empty_embedding_plugin_name_rejected() {
        let _env = EmbeddingEnvScope::acquire();
        // SAFETY: `_env` holds ENV_LOCK for the remainder of this test.
        unsafe { std::env::set_var("KREUZBERG_EMBEDDING_PLUGIN_NAME", "") };

        let mut config = ExtractionConfig::default();
        let err = config
            .apply_env_overrides()
            .expect_err("should reject empty plugin name");

        assert!(
            matches!(err, KreuzbergError::Validation { .. }),
            "expected Validation, got {err:?}"
        );
    }

    #[test]
    fn embedding_plugin_env_sets_chunking_embedding_to_plugin_variant() {
        let _env = EmbeddingEnvScope::acquire();
        // SAFETY: `_env` holds ENV_LOCK for the remainder of this test.
        unsafe { std::env::set_var("KREUZBERG_EMBEDDING_PLUGIN_NAME", "my-embedder") };

        let mut config = ExtractionConfig::default();
        config
            .apply_env_overrides()
            .expect("should succeed with only plugin name set");

        let chunking = config.chunking.as_ref().expect("chunking should be created");
        let embedding = chunking.embedding.as_ref().expect("embedding should be set");
        match &embedding.model {
            EmbeddingModelType::Plugin { name } => {
                assert_eq!(name, "my-embedder");
            }
            other => panic!("expected Plugin variant, got {other:?}"),
        }
    }
}
|
||||
148
crates/kreuzberg/src/core/config/extraction/file_config.rs
Normal file
148
crates/kreuzberg/src/core/config/extraction/file_config.rs
Normal file
@@ -0,0 +1,148 @@
|
||||
//! Per-file extraction configuration overrides for batch processing.
|
||||
//!
|
||||
//! This module contains [`FileExtractionConfig`], a subset of [`super::ExtractionConfig`]
|
||||
//! where every field is optional. When used with batch extraction functions, each file
|
||||
//! can specify overrides that are merged with the batch-level default config.
|
||||
//!
|
||||
//! Fields that are batch-level concerns (concurrency, caching, acceleration, security)
|
||||
//! are intentionally excluded and can only be set on the batch-level [`super::ExtractionConfig`].
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::super::formats::OutputFormat;
|
||||
use super::super::ocr::OcrConfig;
|
||||
use super::super::page::PageConfig;
|
||||
use super::super::processing::{ChunkingConfig, PostProcessorConfig};
|
||||
use super::types::{ImageExtractionConfig, LanguageDetectionConfig, TokenReductionOptions};
|
||||
|
||||
/// Per-file extraction configuration overrides for batch processing.
|
||||
///
|
||||
/// All fields are `Option<T>` — `None` means "use the batch-level default."
|
||||
/// This type is used with [`crate::batch_extract_files`] and
|
||||
/// [`crate::batch_extract_bytes`] to allow heterogeneous
|
||||
/// extraction settings within a single batch.
|
||||
///
|
||||
/// # Excluded Fields
|
||||
///
|
||||
/// The following [`super::ExtractionConfig`] fields are batch-level only and
|
||||
/// cannot be overridden per file:
|
||||
/// - `max_concurrent_extractions` — controls batch parallelism
|
||||
/// - `use_cache` — global caching policy
|
||||
/// - `acceleration` — shared ONNX execution provider
|
||||
/// - `security_limits` — global archive security policy
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::FileExtractionConfig;
|
||||
///
|
||||
/// // Override just OCR forcing for a specific file
|
||||
/// let config = FileExtractionConfig {
|
||||
/// force_ocr: Some(true),
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct FileExtractionConfig {
|
||||
/// Override quality post-processing for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub enable_quality_processing: Option<bool>,
|
||||
|
||||
/// Override OCR configuration for this file (None in the Option = use batch default).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub ocr: Option<OcrConfig>,
|
||||
|
||||
/// Override force OCR for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub force_ocr: Option<bool>,
|
||||
|
||||
/// Override force OCR pages for this file (1-indexed page numbers).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub force_ocr_pages: Option<Vec<u32>>,
|
||||
|
||||
/// Override disable OCR for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub disable_ocr: Option<bool>,
|
||||
|
||||
/// Override chunking configuration for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chunking: Option<ChunkingConfig>,
|
||||
|
||||
/// Override content filtering configuration for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub content_filter: Option<super::super::content_filter::ContentFilterConfig>,
|
||||
|
||||
/// Override image extraction configuration for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub images: Option<ImageExtractionConfig>,
|
||||
|
||||
/// Override PDF options for this file.
|
||||
#[cfg(feature = "pdf")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pdf_options: Option<super::super::pdf::PdfConfig>,
|
||||
|
||||
/// Override token reduction for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub token_reduction: Option<TokenReductionOptions>,
|
||||
|
||||
/// Override language detection for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub language_detection: Option<LanguageDetectionConfig>,
|
||||
|
||||
/// Override page extraction for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub pages: Option<PageConfig>,
|
||||
|
||||
/// Override keyword extraction for this file.
|
||||
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub keywords: Option<crate::keywords::KeywordConfig>,
|
||||
|
||||
/// Override post-processor for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub postprocessor: Option<PostProcessorConfig>,
|
||||
|
||||
/// Override HTML conversion options for this file.
|
||||
#[cfg(feature = "html")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub html_options: Option<html_to_markdown_rs::ConversionOptions>,
|
||||
|
||||
/// Override result format for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub result_format: Option<crate::types::ResultFormat>,
|
||||
|
||||
/// Override output content format for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub output_format: Option<OutputFormat>,
|
||||
|
||||
/// Override document structure output for this file.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub include_document_structure: Option<bool>,
|
||||
|
||||
/// Override layout detection for this file.
|
||||
#[cfg(feature = "layout-types")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub layout: Option<super::super::layout::LayoutDetectionConfig>,
|
||||
|
||||
/// Override per-file extraction timeout in seconds.
|
||||
///
|
||||
/// When set, the extraction for this file will be canceled after the
|
||||
/// specified duration. A timed-out file produces an error result without
|
||||
/// affecting other files in the batch.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub timeout_secs: Option<u64>,
|
||||
|
||||
/// Override tree-sitter configuration for this file.
|
||||
#[cfg(feature = "tree-sitter")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tree_sitter: Option<super::super::tree_sitter::TreeSitterConfig>,
|
||||
|
||||
/// Override structured extraction configuration for this file.
|
||||
///
|
||||
/// When set, enables LLM-based structured extraction with a JSON schema
|
||||
/// for this specific file. The extracted content is sent to a VLM/LLM
|
||||
/// and the response is parsed according to the provided schema.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub structured_extraction: Option<super::super::llm::StructuredExtractionConfig>,
|
||||
}
|
||||
87
crates/kreuzberg/src/core/config/extraction/loaders.rs
Normal file
87
crates/kreuzberg/src/core/config/extraction/loaders.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
//! Configuration file loading.
|
||||
//!
|
||||
//! This module provides methods for loading extraction configuration from
|
||||
//! TOML, YAML, and JSON files.
|
||||
|
||||
use crate::{KreuzbergError, Result};
|
||||
use std::path::Path;
|
||||
|
||||
use super::core::ExtractionConfig;
|
||||
|
||||
impl ExtractionConfig {
|
||||
/// Load configuration from a TOML file.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if file doesn't exist or is invalid TOML.
|
||||
pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
|
||||
toml::from_str(&content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid TOML in {}: {}", path.display(), e)))
|
||||
}
|
||||
|
||||
/// Load configuration from a YAML file.
|
||||
pub fn from_yaml_file(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
|
||||
serde_yaml_ng::from_str(&content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid YAML in {}: {}", path.display(), e)))
|
||||
}
|
||||
|
||||
/// Load configuration from a JSON file.
|
||||
pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
|
||||
serde_json::from_str(&content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid JSON in {}: {}", path.display(), e)))
|
||||
}
|
||||
|
||||
/// Load configuration from a file, auto-detecting format by extension.
|
||||
///
|
||||
/// Supported formats: `.toml`, `.yaml`, `.yml`, `.json`.
|
||||
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
|
||||
let path = path.as_ref();
|
||||
let extension = path.extension().and_then(|ext| ext.to_str()).ok_or_else(|| {
|
||||
KreuzbergError::validation(format!(
|
||||
"Cannot determine file format: no extension found in {}",
|
||||
path.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
match extension.to_lowercase().as_str() {
|
||||
"toml" => Self::from_toml_file(path),
|
||||
"yaml" | "yml" => Self::from_yaml_file(path),
|
||||
"json" => Self::from_json_file(path),
|
||||
other => Err(KreuzbergError::validation(format!(
|
||||
"Unsupported config file format: .{}. Supported formats: .toml, .yaml, .json",
|
||||
other
|
||||
))),
|
||||
}
|
||||
}
|
||||
|
||||
/// Discover configuration file in parent directories.
|
||||
///
|
||||
/// Searches for `kreuzberg.toml` in current directory and parent directories.
|
||||
pub fn discover() -> Result<Option<Self>> {
|
||||
let mut current = std::env::current_dir().map_err(KreuzbergError::Io)?;
|
||||
|
||||
loop {
|
||||
let kreuzberg_toml = current.join("kreuzberg.toml");
|
||||
if kreuzberg_toml.exists() {
|
||||
return Ok(Some(Self::from_toml_file(kreuzberg_toml)?));
|
||||
}
|
||||
|
||||
if let Some(parent) = current.parent() {
|
||||
current = parent.to_path_buf();
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
46
crates/kreuzberg/src/core/config/extraction/mod.rs
Normal file
46
crates/kreuzberg/src/core/config/extraction/mod.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
//! Main extraction configuration and environment variable handling.
|
||||
//!
|
||||
//! This module contains the main `ExtractionConfig` struct and related utilities
|
||||
//! for loading configuration from files and applying environment variable overrides.
|
||||
//!
|
||||
//! The module is organized into focused submodules:
|
||||
//! - `types`: Feature-specific configuration types (image, token reduction, language detection)
|
||||
//! - `core`: Main ExtractionConfig struct and implementation
|
||||
//! - `env`: Environment variable override support
|
||||
//! - `file_config`: Per-file configuration overrides for batch processing
//! - `loaders`: Configuration file loading (TOML, YAML, JSON) and discovery
|
||||
|
||||
mod core;
|
||||
mod env;
|
||||
mod file_config;
|
||||
mod loaders;
|
||||
mod types;
|
||||
|
||||
// Re-export all public types for backward compatibility
|
||||
pub use self::core::ExtractionConfig;
|
||||
pub use self::file_config::FileExtractionConfig;
|
||||
pub use self::types::{
|
||||
BatchBytesItem, BatchFileItem, ImageExtractionConfig, LanguageDetectionConfig, TokenReductionOptions,
|
||||
};
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::config::ocr::OcrConfig;

    /// A default config has caching and quality processing on and no OCR config.
    #[test]
    fn test_default_config() {
        let defaults = ExtractionConfig::default();

        assert!(defaults.use_cache);
        assert!(defaults.enable_quality_processing);
        assert!(defaults.ocr.is_none());
    }

    /// `needs_image_processing` flips to true once an OCR config is attached.
    #[test]
    fn test_needs_image_processing() {
        let mut cfg = ExtractionConfig::default();
        assert!(!cfg.needs_image_processing());

        cfg.ocr = Some(OcrConfig::default());
        assert!(cfg.needs_image_processing());
    }
}
|
||||
345
crates/kreuzberg/src/core/config/extraction/types.rs
Normal file
345
crates/kreuzberg/src/core/config/extraction/types.rs
Normal file
@@ -0,0 +1,345 @@
|
||||
//! Feature-specific configuration types for extraction.
|
||||
//!
|
||||
//! This module contains configuration structs for specific extraction features:
|
||||
//! - Image extraction and processing
|
||||
//! - Token reduction
|
||||
//! - Language detection
|
||||
//! - Batch extraction items
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Batch item for byte array extraction.
|
||||
///
|
||||
/// Used with [`crate::batch_extract_bytes`] and [`crate::batch_extract_bytes_sync`]
|
||||
/// to represent a single item in a batch extraction job.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BatchBytesItem {
|
||||
/// The content bytes to extract from
|
||||
pub content: Vec<u8>,
|
||||
|
||||
/// MIME type of the content (e.g., "application/pdf", "text/html")
|
||||
pub mime_type: String,
|
||||
|
||||
/// Per-item configuration overrides (None uses batch-level defaults)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub config: Option<super::FileExtractionConfig>,
|
||||
}
|
||||
|
||||
/// Batch item for file extraction.
|
||||
///
|
||||
/// Used with [`crate::batch_extract_files`] and [`crate::batch_extract_files_sync`]
|
||||
/// to represent a single file in a batch extraction job.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BatchFileItem {
|
||||
/// Path to the file to extract from
|
||||
pub path: PathBuf,
|
||||
|
||||
/// Per-file configuration overrides (None uses batch-level defaults)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub config: Option<super::FileExtractionConfig>,
|
||||
}
|
||||
|
||||
/// Image extraction configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ImageExtractionConfig {
|
||||
/// Extract images from documents
|
||||
#[serde(default = "default_true")]
|
||||
pub extract_images: bool,
|
||||
|
||||
/// Target DPI for image normalization
|
||||
#[serde(default = "default_target_dpi")]
|
||||
pub target_dpi: i32,
|
||||
|
||||
/// Maximum dimension for images (width or height)
|
||||
#[serde(default = "default_max_dimension")]
|
||||
pub max_image_dimension: i32,
|
||||
|
||||
/// Whether to inject image reference placeholders into markdown output.
|
||||
/// When `true` (default), image references like ``
|
||||
/// are appended to the markdown. Set to `false` to extract images as data
|
||||
/// without polluting the markdown output.
|
||||
#[serde(default = "default_true")]
|
||||
pub inject_placeholders: bool,
|
||||
|
||||
/// Automatically adjust DPI based on image content
|
||||
#[serde(default = "default_true")]
|
||||
pub auto_adjust_dpi: bool,
|
||||
|
||||
/// Minimum DPI threshold
|
||||
#[serde(default = "default_min_dpi")]
|
||||
pub min_dpi: i32,
|
||||
|
||||
/// Maximum DPI threshold
|
||||
#[serde(default = "default_max_dpi")]
|
||||
pub max_dpi: i32,
|
||||
|
||||
/// Maximum number of image objects to extract per PDF page.
|
||||
///
|
||||
/// Some PDFs (e.g. technical diagrams stored as thousands of raster fragments)
|
||||
/// can trigger extremely long or indefinite extraction times when every image
|
||||
/// object on a dense page is decoded individually via the PDF extractor. Setting this
|
||||
/// limit causes kreuzberg to stop collecting individual images once the count
|
||||
/// per page reaches the cap and emit a warning instead.
|
||||
///
|
||||
/// `None` (default) means no limit — all images are extracted.
|
||||
#[serde(default)]
|
||||
pub max_images_per_page: Option<u32>,
|
||||
|
||||
/// When `true` (default), extracted images are classified by kind and grouped
|
||||
/// into clusters where they appear to belong to one figure.
|
||||
#[serde(default = "default_true")]
|
||||
pub classify: bool,
|
||||
|
||||
/// When `true`, full-page renders produced during OCR preprocessing are captured
|
||||
/// and returned as `ImageKind::PageRaster` entries in `ExtractionResult.images`.
|
||||
///
|
||||
/// **PDF + OCR only.** No rasters are captured for non-PDF inputs or when the
|
||||
/// document-level OCR bypass is active (whole-document backend). When OCR is
|
||||
/// enabled and this flag is set but the active backend skips per-page rendering,
|
||||
/// a `ProcessingWarning` is emitted in `ExtractionResult.processing_warnings`.
|
||||
///
|
||||
/// Defaults to `false`. Enable when downstream consumers need page thumbnails
|
||||
/// (e.g. citation previews, visual grounding).
|
||||
#[serde(default)]
|
||||
pub include_page_rasters: bool,
|
||||
|
||||
/// Run OCR on extracted images and include the recognized text in the document content.
|
||||
///
|
||||
/// When `true` (default) and `ExtractionConfig.ocr` is configured, extracted images
|
||||
/// are processed with the configured OCR backend. Set to `false` to extract images
|
||||
/// without OCR processing, even when OCR is enabled.
|
||||
#[serde(default = "default_true")]
|
||||
pub run_ocr_on_images: bool,
|
||||
|
||||
/// When `true`, image OCR results are rendered as plain text without the
|
||||
/// `` markdown placeholder. Only takes effect when `run_ocr_on_images`
|
||||
/// is also `true`.
|
||||
#[serde(default)]
|
||||
pub ocr_text_only: bool,
|
||||
|
||||
/// When `true` and `ocr_text_only` is `false`, append the OCR text after
|
||||
/// the image placeholder in the rendered output.
|
||||
#[serde(default)]
|
||||
pub append_ocr_text: bool,
|
||||
}
|
||||
|
||||
/// Token reduction configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TokenReductionOptions {
|
||||
/// Reduction mode: "off", "light", "moderate", "aggressive", "maximum"
|
||||
#[serde(default = "default_reduction_mode")]
|
||||
pub mode: String,
|
||||
|
||||
/// Preserve important words (capitalized, technical terms)
|
||||
#[serde(default = "default_true")]
|
||||
pub preserve_important_words: bool,
|
||||
}
|
||||
|
||||
/// Language detection configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LanguageDetectionConfig {
|
||||
/// Enable language detection
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Minimum confidence threshold (0.0-1.0)
|
||||
#[serde(default = "default_confidence")]
|
||||
pub min_confidence: f64,
|
||||
|
||||
/// Detect multiple languages in the document
|
||||
#[serde(default)]
|
||||
pub detect_multiple: bool,
|
||||
}
|
||||
|
||||
impl Default for ImageExtractionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
extract_images: true,
|
||||
target_dpi: 300,
|
||||
max_image_dimension: 4096,
|
||||
inject_placeholders: true,
|
||||
auto_adjust_dpi: true,
|
||||
min_dpi: 72,
|
||||
max_dpi: 600,
|
||||
max_images_per_page: None,
|
||||
classify: true,
|
||||
include_page_rasters: false,
|
||||
run_ocr_on_images: true,
|
||||
ocr_text_only: false,
|
||||
append_ocr_text: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TokenReductionOptions {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
mode: default_reduction_mode(),
|
||||
preserve_important_words: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for LanguageDetectionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
min_confidence: 0.8,
|
||||
detect_multiple: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default value functions
|
||||
/// Serde default helper for boolean fields that are enabled unless set otherwise.
fn default_true() -> bool {
    true
}
|
||||
|
||||
/// Default target DPI: 300.
fn default_target_dpi() -> i32 {
    300
}
|
||||
|
||||
/// Default maximum image dimension: 4096.
fn default_max_dimension() -> i32 {
    4096
}
|
||||
|
||||
/// Default lower DPI bound: 72.
fn default_min_dpi() -> i32 {
    72
}
|
||||
|
||||
/// Default upper DPI bound: 600.
fn default_max_dpi() -> i32 {
    600
}
|
||||
|
||||
/// Serde default helper for `TokenReductionOptions::mode`: reduction is "off".
fn default_reduction_mode() -> String {
    String::from("off")
}
|
||||
|
||||
/// Serde default helper for `LanguageDetectionConfig::min_confidence`: 0.8.
fn default_confidence() -> f64 {
    0.8
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Config JSON that carries neither `max_images_per_page` nor
    /// `include_page_rasters`; shared by the backward-compatibility tests.
    const LEGACY_CONFIG_JSON: &str = r#"{"extract_images":true,"target_dpi":300,"max_image_dimension":4096,
        "inject_placeholders":true,"auto_adjust_dpi":true,
        "min_dpi":72,"max_dpi":600}"#;

    #[test]
    fn test_image_extraction_config_default_booleans_are_true() {
        let defaults = ImageExtractionConfig::default();

        // Boolean switches that must be on by default.
        assert!(defaults.extract_images, "extract_images must default to true");
        assert!(defaults.inject_placeholders, "inject_placeholders must default to true");
        assert!(defaults.auto_adjust_dpi, "auto_adjust_dpi must default to true");
        assert!(defaults.classify, "classify must default to true");

        // Numeric defaults.
        assert_eq!(defaults.target_dpi, 300);
        assert_eq!(defaults.max_image_dimension, 4096);
        assert_eq!(defaults.min_dpi, 72);
        assert_eq!(defaults.max_dpi, 600);
    }

    #[test]
    fn test_image_extraction_config_defaults() {
        let defaults = ImageExtractionConfig::default();
        assert!(defaults.run_ocr_on_images, "run_ocr_on_images must default to true");
    }

    #[test]
    fn test_image_extraction_config_explicit_false_disables_placeholders() {
        let overridden = ImageExtractionConfig {
            inject_placeholders: false,
            ..ImageExtractionConfig::default()
        };
        assert!(!overridden.inject_placeholders);
        // Untouched fields keep their defaults.
        assert!(overridden.extract_images);
    }

    #[test]
    fn test_image_extraction_config_explicit_false_disables_classify() {
        let overridden = ImageExtractionConfig {
            classify: false,
            ..ImageExtractionConfig::default()
        };
        assert!(!overridden.classify);
        // Untouched fields keep their defaults.
        assert!(overridden.extract_images);
    }

    #[test]
    fn test_image_extraction_config_absent_json_fields_get_canonical_defaults() {
        let parsed: ImageExtractionConfig = serde_json::from_str(r#"{"extract_images": true}"#).unwrap();
        assert!(
            parsed.inject_placeholders,
            "absent inject_placeholders must deserialize to true"
        );
        assert!(parsed.auto_adjust_dpi, "absent auto_adjust_dpi must deserialize to true");
        assert_eq!(parsed.target_dpi, 300);
    }

    #[test]
    fn test_max_images_per_page_defaults_none() {
        assert_eq!(ImageExtractionConfig::default().max_images_per_page, None);
    }

    #[test]
    fn test_max_images_per_page_serializes_as_null_when_none() {
        let serialized = serde_json::to_string(&ImageExtractionConfig::default()).unwrap();
        assert!(serialized.contains("\"max_images_per_page\":null"));
    }

    #[test]
    fn test_max_images_per_page_roundtrips_via_json() {
        let original = ImageExtractionConfig {
            max_images_per_page: Some(50),
            ..Default::default()
        };
        let serialized = serde_json::to_string(&original).unwrap();
        let restored: ImageExtractionConfig = serde_json::from_str(&serialized).unwrap();
        assert_eq!(restored.max_images_per_page, Some(50));
    }

    /// Regression test for issue #766: a JSON document without the
    /// `max_images_per_page` key must still deserialize cleanly
    /// (backwards compatibility with existing configs).
    #[test]
    fn test_max_images_per_page_absent_in_json_deserializes_as_none() {
        let parsed: ImageExtractionConfig = serde_json::from_str(LEGACY_CONFIG_JSON).unwrap();
        assert_eq!(parsed.max_images_per_page, None);
    }

    #[test]
    fn test_include_page_rasters_defaults_false() {
        let defaults = ImageExtractionConfig::default();
        assert!(
            !defaults.include_page_rasters,
            "include_page_rasters must default to false"
        );
    }

    #[test]
    fn test_include_page_rasters_absent_in_json_deserializes_as_false() {
        let parsed: ImageExtractionConfig = serde_json::from_str(LEGACY_CONFIG_JSON).unwrap();
        assert!(
            !parsed.include_page_rasters,
            "absent include_page_rasters must deserialize to false (backward compat)"
        );
    }

    #[test]
    fn test_include_page_rasters_roundtrips_via_json() {
        let original = ImageExtractionConfig {
            include_page_rasters: true,
            ..Default::default()
        };
        let serialized = serde_json::to_string(&original).unwrap();
        let restored: ImageExtractionConfig = serde_json::from_str(&serialized).unwrap();
        assert!(restored.include_page_rasters);
    }
}
|
||||
201
crates/kreuzberg/src/core/config/formats.rs
Normal file
201
crates/kreuzberg/src/core/config/formats.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
//! Output format configuration and validation.
|
||||
//!
|
||||
//! This module defines the `OutputFormat` enum for controlling how extraction
|
||||
//! results are formatted (plain text, markdown, HTML, etc.) and provides
|
||||
//! serialization/deserialization support.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::str::FromStr;
|
||||
|
||||
/// Output format for extraction results.
|
||||
///
|
||||
/// Controls the format of the `content` field in `ExtractionResult`.
|
||||
/// When set to `Markdown`, `Djot`, or `Html`, the output uses that format.
|
||||
/// `Plain` returns the raw extracted text.
|
||||
/// `Structured` returns JSON with full OCR element data including bounding
|
||||
/// boxes and confidence scores.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum OutputFormat {
|
||||
/// Plain text content only (default)
|
||||
#[default]
|
||||
Plain,
|
||||
/// Markdown format
|
||||
Markdown,
|
||||
/// Djot markup format
|
||||
Djot,
|
||||
/// HTML format
|
||||
Html,
|
||||
/// JSON tree format with heading-driven sections.
|
||||
Json,
|
||||
/// Structured JSON format with full OCR element metadata.
|
||||
Structured,
|
||||
/// Custom renderer registered via the RendererRegistry.
|
||||
/// The string is the renderer name (e.g., "docx", "latex").
|
||||
#[serde(untagged)]
|
||||
Custom(String),
|
||||
}
|
||||
|
||||
#[cfg(test)]
impl OutputFormat {
    /// Get the renderer name for this format.
    ///
    /// Returns `None` for the formats that have no named renderer here
    /// (`Plain`, `Json`, `Structured`). The built-in markup formats map to
    /// their lowercase name, and `Custom` yields the name it carries.
    ///
    /// Only compiled for tests.
    pub(crate) fn renderer_name(&self) -> Option<&str> {
        match self {
            OutputFormat::Markdown => Some("markdown"),
            OutputFormat::Djot => Some("djot"),
            OutputFormat::Html => Some("html"),
            OutputFormat::Custom(name) => Some(name.as_str()),
            OutputFormat::Plain | OutputFormat::Json | OutputFormat::Structured => None,
        }
    }
}
|
||||
|
||||
impl std::fmt::Display for OutputFormat {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
OutputFormat::Plain => write!(f, "plain"),
|
||||
OutputFormat::Markdown => write!(f, "markdown"),
|
||||
OutputFormat::Djot => write!(f, "djot"),
|
||||
OutputFormat::Html => write!(f, "html"),
|
||||
OutputFormat::Json => write!(f, "json"),
|
||||
OutputFormat::Structured => write!(f, "structured"),
|
||||
OutputFormat::Custom(name) => write!(f, "{}", name),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for OutputFormat {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
|
||||
match s.to_lowercase().as_str() {
|
||||
"plain" | "text" => Ok(OutputFormat::Plain),
|
||||
"markdown" | "md" => Ok(OutputFormat::Markdown),
|
||||
"djot" => Ok(OutputFormat::Djot),
|
||||
"html" => Ok(OutputFormat::Html),
|
||||
"json" => Ok(OutputFormat::Json),
|
||||
"structured" | "structured-ocr" => Ok(OutputFormat::Structured),
|
||||
other => Ok(OutputFormat::Custom(other.to_string())),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every spelling in `inputs` parses to `expected`.
    fn assert_all_parse_to(inputs: &[&str], expected: &OutputFormat) {
        for input in inputs {
            let parsed = input.parse::<OutputFormat>().unwrap();
            assert_eq!(&parsed, expected);
        }
    }

    #[test]
    fn test_output_format_from_str_plain() {
        assert_all_parse_to(&["plain", "PLAIN", "text", "TEXT"], &OutputFormat::Plain);
    }

    #[test]
    fn test_output_format_from_str_markdown() {
        assert_all_parse_to(&["markdown", "MARKDOWN", "md", "MD"], &OutputFormat::Markdown);
    }

    #[test]
    fn test_output_format_from_str_djot() {
        assert_all_parse_to(&["djot", "DJOT", "Djot"], &OutputFormat::Djot);
    }

    #[test]
    fn test_output_format_from_str_html() {
        assert_all_parse_to(&["html", "HTML", "Html"], &OutputFormat::Html);
    }

    #[test]
    fn test_output_format_from_str_json() {
        assert_all_parse_to(&["json", "JSON"], &OutputFormat::Json);
    }

    #[test]
    fn test_output_format_from_str_structured() {
        assert_all_parse_to(
            &["structured", "STRUCTURED", "structured-ocr", "STRUCTURED-OCR"],
            &OutputFormat::Structured,
        );
    }

    #[test]
    fn test_output_format_from_str_custom() {
        let parsed = "docx".parse::<OutputFormat>().unwrap();
        assert_eq!(parsed, OutputFormat::Custom("docx".to_string()));
    }

    #[test]
    fn test_output_format_to_string() {
        let cases = [
            (OutputFormat::Plain, "plain"),
            (OutputFormat::Markdown, "markdown"),
            (OutputFormat::Djot, "djot"),
            (OutputFormat::Html, "html"),
            (OutputFormat::Json, "json"),
            (OutputFormat::Structured, "structured"),
            (OutputFormat::Custom("docx".to_string()), "docx"),
        ];
        for (format, expected) in cases {
            assert_eq!(format.to_string(), expected);
        }
    }

    #[test]
    fn test_output_format_default() {
        assert_eq!(OutputFormat::default(), OutputFormat::Plain);
    }

    #[test]
    fn test_output_format_serde_roundtrip() {
        let built_ins = [
            OutputFormat::Plain,
            OutputFormat::Markdown,
            OutputFormat::Djot,
            OutputFormat::Html,
            OutputFormat::Json,
            OutputFormat::Structured,
        ];
        for original in built_ins {
            let serialized = serde_json::to_string(&original).unwrap();
            let restored: OutputFormat = serde_json::from_str(&serialized).unwrap();
            assert_eq!(original, restored);
        }
    }

    #[test]
    fn test_output_format_serde_values() {
        let cases = [
            (OutputFormat::Plain, "\"plain\""),
            (OutputFormat::Markdown, "\"markdown\""),
            (OutputFormat::Djot, "\"djot\""),
            (OutputFormat::Html, "\"html\""),
            (OutputFormat::Json, "\"json\""),
            (OutputFormat::Structured, "\"structured\""),
        ];
        for (format, expected) in cases {
            assert_eq!(serde_json::to_string(&format).unwrap(), expected);
        }
    }

    #[test]
    fn test_output_format_renderer_name() {
        // Formats without a named renderer.
        assert_eq!(OutputFormat::Plain.renderer_name(), None);
        assert_eq!(OutputFormat::Json.renderer_name(), None);
        assert_eq!(OutputFormat::Structured.renderer_name(), None);

        // Built-in markup renderers.
        assert_eq!(OutputFormat::Markdown.renderer_name(), Some("markdown"));
        assert_eq!(OutputFormat::Html.renderer_name(), Some("html"));
        assert_eq!(OutputFormat::Djot.renderer_name(), Some("djot"));

        // Custom renderers report the name they carry.
        assert_eq!(OutputFormat::Custom("docx".to_string()).renderer_name(), Some("docx"));
    }
}
|
||||
136
crates/kreuzberg/src/core/config/html_output.rs
Normal file
136
crates/kreuzberg/src/core/config/html_output.rs
Normal file
@@ -0,0 +1,136 @@
|
||||
//! HTML output configuration.
|
||||
//!
|
||||
//! Controls how `OutputFormat::Html` renders an `InternalDocument`:
|
||||
//! which built-in theme to use, whether to embed the CSS in a `<style>`
|
||||
//! block, and optional user-supplied CSS (inline string or file path).
|
||||
|
||||
use std::path::PathBuf;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Serde default helper for `HtmlOutputConfig::class_prefix`: `"kb-"`.
fn default_class_prefix() -> String {
    String::from("kb-")
}
|
||||
|
||||
/// Serde default helper for boolean fields that are enabled unless set otherwise.
fn default_true() -> bool {
    true
}
|
||||
|
||||
/// Configuration for styled HTML output.
|
||||
///
|
||||
/// When set on [`ExtractionConfig::html_output`] alongside
|
||||
/// `output_format = OutputFormat::Html`, the pipeline builds a
|
||||
/// [`StyledHtmlRenderer`](crate::rendering::StyledHtmlRenderer) instead of
|
||||
/// the plain comrak-based renderer.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::core::config::{HtmlOutputConfig, HtmlTheme};
|
||||
///
|
||||
/// let config = HtmlOutputConfig {
|
||||
/// theme: HtmlTheme::GitHub,
|
||||
/// css: Some(".kb-p { font-size: 1.1rem; }".to_string()),
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HtmlOutputConfig {
|
||||
/// Inline CSS string injected into the output after the theme stylesheet.
|
||||
/// Concatenated after `css_file` content when both are set.
|
||||
#[serde(default)]
|
||||
pub css: Option<String>,
|
||||
|
||||
/// Path to a CSS file loaded once at renderer construction time.
|
||||
/// Concatenated before `css` when both are set.
|
||||
#[serde(default)]
|
||||
pub css_file: Option<PathBuf>,
|
||||
|
||||
/// Built-in colour/typography theme. Default: [`HtmlTheme::Unstyled`].
|
||||
#[serde(default)]
|
||||
pub theme: HtmlTheme,
|
||||
|
||||
/// CSS class prefix applied to every emitted class name.
|
||||
///
|
||||
/// Default: `"kb-"`. Change this if your host application already uses
|
||||
/// classes that start with `kb-`.
|
||||
#[serde(default = "default_class_prefix")]
|
||||
pub class_prefix: String,
|
||||
|
||||
/// When `true` (default), write the resolved CSS into a `<style>` block
|
||||
/// immediately after the opening `<div class="{prefix}doc">`.
|
||||
///
|
||||
/// Set to `false` to emit only the structural markup and wire up your
|
||||
/// own stylesheet targeting the `kb-*` class names.
|
||||
#[serde(default = "default_true")]
|
||||
pub embed_css: bool,
|
||||
}
|
||||
|
||||
impl Default for HtmlOutputConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
css: None,
|
||||
css_file: None,
|
||||
theme: HtmlTheme::Unstyled,
|
||||
class_prefix: default_class_prefix(),
|
||||
embed_css: true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Built-in HTML theme selection.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum HtmlTheme {
|
||||
/// Sensible defaults: system font stack, neutral colours, readable line
|
||||
/// measure. CSS custom properties (`--kb-*`) are all defined so user CSS
|
||||
/// can override individual values.
|
||||
Default,
|
||||
/// GitHub Markdown-inspired palette and spacing.
|
||||
GitHub,
|
||||
/// Dark background, light text.
|
||||
Dark,
|
||||
/// Minimal light theme with generous whitespace.
|
||||
Light,
|
||||
/// No built-in stylesheet emitted. CSS custom properties are still defined
|
||||
/// on `:root` so user stylesheets can reference `var(--kb-*)` tokens.
|
||||
#[default]
|
||||
Unstyled,
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn default_config_values() {
        let defaults = HtmlOutputConfig::default();

        assert_eq!(defaults.class_prefix, "kb-");
        assert!(defaults.embed_css);
        // No user CSS is configured out of the box.
        assert!(defaults.css.is_none());
        assert!(defaults.css_file.is_none());
        assert_eq!(defaults.theme, HtmlTheme::Unstyled);
    }

    #[test]
    fn serde_roundtrip() {
        let original = HtmlOutputConfig {
            css: Some(".kb-p { color: red; }".to_string()),
            theme: HtmlTheme::GitHub,
            embed_css: false,
            ..Default::default()
        };

        let serialized = serde_json::to_string(&original).unwrap();
        let restored: HtmlOutputConfig = serde_json::from_str(&serialized).unwrap();

        assert_eq!(restored.css, original.css);
        assert_eq!(restored.theme, HtmlTheme::GitHub);
        assert!(!restored.embed_css);
    }

    #[test]
    fn theme_serde() {
        // Serialization uses the lowercase variant name.
        let serialized = serde_json::to_string(&HtmlTheme::GitHub).unwrap();
        assert_eq!(serialized, "\"github\"");

        // Deserialization accepts the same lowercase spelling.
        let parsed: HtmlTheme = serde_json::from_str("\"dark\"").unwrap();
        assert_eq!(parsed, HtmlTheme::Dark);
    }
}
|
||||
180
crates/kreuzberg/src/core/config/layout.rs
Normal file
180
crates/kreuzberg/src/core/config/layout.rs
Normal file
@@ -0,0 +1,180 @@
|
||||
//! Layout detection configuration.
|
||||
|
||||
use std::fmt;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Which table structure recognition model to use.
|
||||
///
|
||||
/// Controls the model used for table cell detection within layout-detected
|
||||
/// table regions. Wire format is snake_case in all serializers (JSON, TOML,
|
||||
/// YAML).
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TableModel {
|
||||
/// TATR (Table Transformer) -- default, 30MB, DETR-based row/column detection.
|
||||
#[default]
|
||||
Tatr,
|
||||
/// SLANeXT wired variant -- 365MB, optimized for bordered tables.
|
||||
SlanetWired,
|
||||
/// SLANeXT wireless variant -- 365MB, optimized for borderless tables.
|
||||
SlanetWireless,
|
||||
/// SLANet-plus -- 7.78MB, lightweight general-purpose.
|
||||
SlanetPlus,
|
||||
/// Classifier-routed SLANeXT: auto-select wired/wireless per table.
|
||||
/// Uses PP-LCNet classifier (6.78MB) + both SLANeXT variants (730MB total).
|
||||
SlanetAuto,
|
||||
/// Disable table structure model inference entirely; use heuristic path only.
|
||||
Disabled,
|
||||
}
|
||||
|
||||
impl std::str::FromStr for TableModel {
|
||||
type Err = String;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
match s {
|
||||
"tatr" => Ok(Self::Tatr),
|
||||
"slanet_wired" => Ok(Self::SlanetWired),
|
||||
"slanet_wireless" => Ok(Self::SlanetWireless),
|
||||
"slanet_plus" => Ok(Self::SlanetPlus),
|
||||
"slanet_auto" => Ok(Self::SlanetAuto),
|
||||
"disabled" => Ok(Self::Disabled),
|
||||
other => Err(format!(
|
||||
"unknown table model: '{other}'. Valid: tatr, slanet_wired, slanet_wireless, slanet_plus, slanet_auto, disabled"
|
||||
)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for TableModel {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
TableModel::Tatr => write!(f, "tatr"),
|
||||
TableModel::SlanetWired => write!(f, "slanet_wired"),
|
||||
TableModel::SlanetWireless => write!(f, "slanet_wireless"),
|
||||
TableModel::SlanetPlus => write!(f, "slanet_plus"),
|
||||
TableModel::SlanetAuto => write!(f, "slanet_auto"),
|
||||
TableModel::Disabled => write!(f, "disabled"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Layout detection configuration.
|
||||
///
|
||||
/// Controls layout detection behavior in the extraction pipeline.
|
||||
/// When set on [`ExtractionConfig`](super::ExtractionConfig), layout detection
|
||||
/// is enabled for PDF extraction.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LayoutDetectionConfig {
|
||||
/// Confidence threshold override (None = use model default).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub confidence_threshold: Option<f32>,
|
||||
|
||||
/// Whether to apply postprocessing heuristics (default: true).
|
||||
#[serde(default = "default_true")]
|
||||
pub apply_heuristics: bool,
|
||||
|
||||
/// Table structure recognition model.
|
||||
///
|
||||
/// Controls which model is used for table cell detection within layout-detected
|
||||
/// table regions. Defaults to [`TableModel::Tatr`].
|
||||
#[serde(default)]
|
||||
pub table_model: TableModel,
|
||||
|
||||
/// Hardware acceleration for ONNX models (layout detection + table structure).
|
||||
///
|
||||
/// When set, controls which execution provider (CPU, CUDA, CoreML, TensorRT)
|
||||
/// is used for inference. Defaults to `None` (auto-select per platform).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub acceleration: Option<super::acceleration::AccelerationConfig>,
|
||||
}
|
||||
|
||||
impl Default for LayoutDetectionConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
confidence_threshold: None,
|
||||
apply_heuristics: true,
|
||||
table_model: TableModel::default(),
|
||||
acceleration: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde default helper for `LayoutDetectionConfig::apply_heuristics`: `true`.
fn default_true() -> bool {
    true
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let defaults = LayoutDetectionConfig::default();
        assert_eq!(defaults.table_model, TableModel::Tatr);
        assert!(defaults.apply_heuristics);
        assert!(defaults.confidence_threshold.is_none());
    }

    #[test]
    fn test_table_model_deserialize() {
        let cases = [
            (r#""tatr""#, TableModel::Tatr),
            (r#""slanet_auto""#, TableModel::SlanetAuto),
            (r#""disabled""#, TableModel::Disabled),
        ];
        for (json, expected) in cases {
            let parsed: TableModel = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    #[test]
    fn test_table_model_serialize() {
        let serialized = serde_json::to_string(&TableModel::SlanetWired).unwrap();
        assert_eq!(serialized, r#""slanet_wired""#);
    }

    #[test]
    fn test_table_model_round_trip() {
        let all_models = [
            TableModel::Tatr,
            TableModel::SlanetWired,
            TableModel::SlanetWireless,
            TableModel::SlanetPlus,
            TableModel::SlanetAuto,
            TableModel::Disabled,
        ];
        for model in all_models {
            let serialized = serde_json::to_string(&model).unwrap();
            let parsed: TableModel = serde_json::from_str(&serialized).unwrap();
            assert_eq!(parsed, model, "round-trip failed for {model:?}");
        }
    }

    #[test]
    fn test_backward_compat_unknown_fields_ignored() {
        // Old configs with "preset" field should still deserialize because
        // serde ignores unknown fields by default.
        let legacy_json = r#"{"preset": "accurate", "apply_heuristics": true}"#;
        let parsed: LayoutDetectionConfig = serde_json::from_str(legacy_json).unwrap();
        assert!(parsed.apply_heuristics);
        assert_eq!(parsed.table_model, TableModel::Tatr);
    }

    #[test]
    fn test_backward_compat_old_table_model_field() {
        // Old configs with table_model as a string should still work.
        let legacy_json = r#"{"table_model": "slanet_wired"}"#;
        let parsed: LayoutDetectionConfig = serde_json::from_str(legacy_json).unwrap();
        assert_eq!(parsed.table_model, TableModel::SlanetWired);
    }

    #[test]
    fn test_table_model_display() {
        let cases = [
            (TableModel::Tatr, "tatr"),
            (TableModel::SlanetWired, "slanet_wired"),
            (TableModel::Disabled, "disabled"),
        ];
        for (model, expected) in cases {
            assert_eq!(model.to_string(), expected);
        }
    }
}
|
||||
154
crates/kreuzberg/src/core/config/llm.rs
Normal file
154
crates/kreuzberg/src/core/config/llm.rs
Normal file
@@ -0,0 +1,154 @@
|
||||
//! LLM configuration types for liter-llm integration.
|
||||
//!
|
||||
//! These types are always available (not feature-gated) since they are
|
||||
//! pure configuration data with no runtime dependency on liter-llm.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Configuration for an LLM provider/model via liter-llm.
|
||||
///
|
||||
/// Each feature (VLM OCR, VLM embeddings, structured extraction) carries
|
||||
/// its own `LlmConfig`, allowing different providers per feature.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```toml
|
||||
/// [structured_extraction.llm]
|
||||
/// model = "openai/gpt-4o"
|
||||
/// api_key = "sk-..." # or use KREUZBERG_LLM_API_KEY env var
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct LlmConfig {
|
||||
/// Provider/model string using liter-llm routing format.
|
||||
///
|
||||
/// Examples: `"openai/gpt-4o"`, `"anthropic/claude-sonnet-4-20250514"`,
|
||||
/// `"groq/llama-3.1-70b-versatile"`.
|
||||
pub model: String,
|
||||
|
||||
/// API key for the provider. When `None`, liter-llm falls back to
|
||||
/// the provider's standard environment variable (e.g., `OPENAI_API_KEY`).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub api_key: Option<String>,
|
||||
|
||||
/// Custom base URL override for the provider endpoint.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub base_url: Option<String>,
|
||||
|
||||
/// Request timeout in seconds (default: 60).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub timeout_secs: Option<u64>,
|
||||
|
||||
/// Maximum retry attempts (default: 3).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_retries: Option<u32>,
|
||||
|
||||
/// Sampling temperature for generation tasks.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub temperature: Option<f64>,
|
||||
|
||||
/// Maximum tokens to generate.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_tokens: Option<u64>,
|
||||
}
|
||||
|
||||
/// Configuration for LLM-based structured data extraction.
|
||||
///
|
||||
/// Sends extracted document content to a VLM with a JSON schema,
|
||||
/// returning structured data that conforms to the schema.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```toml
|
||||
/// [structured_extraction]
|
||||
/// schema_name = "invoice_data"
|
||||
/// strict = true
|
||||
///
|
||||
/// [structured_extraction.schema]
|
||||
/// type = "object"
|
||||
/// properties.vendor = { type = "string" }
|
||||
/// properties.total = { type = "number" }
|
||||
/// required = ["vendor", "total"]
|
||||
///
|
||||
/// [structured_extraction.llm]
|
||||
/// model = "openai/gpt-4o"
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StructuredExtractionConfig {
|
||||
/// JSON Schema defining the desired output structure.
|
||||
pub schema: serde_json::Value,
|
||||
|
||||
/// Schema name passed to the LLM's structured output mode.
|
||||
#[serde(default = "default_schema_name")]
|
||||
pub schema_name: String,
|
||||
|
||||
/// Optional schema description for the LLM.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub schema_description: Option<String>,
|
||||
|
||||
/// Enable strict mode — output must exactly match the schema.
|
||||
#[serde(default)]
|
||||
pub strict: bool,
|
||||
|
||||
/// Custom Jinja2 extraction prompt template. When `None`, a default template is used.
|
||||
///
|
||||
/// Available template variables:
|
||||
/// - `{{ content }}` — The extracted document text.
|
||||
/// - `{{ schema }}` — The JSON schema as a formatted string.
|
||||
/// - `{{ schema_name }}` — The schema name.
|
||||
/// - `{{ schema_description }}` — The schema description (may be empty).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub prompt: Option<String>,
|
||||
|
||||
/// LLM configuration for the extraction.
|
||||
pub llm: LlmConfig,
|
||||
}
|
||||
|
||||
/// Serde default helper for `StructuredExtractionConfig::schema_name`: `"extraction"`.
fn default_schema_name() -> String {
    String::from("extraction")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that every optional field of `cfg` is unset.
    fn assert_optionals_unset(cfg: &LlmConfig) {
        assert!(cfg.api_key.is_none());
        assert!(cfg.base_url.is_none());
        assert!(cfg.timeout_secs.is_none());
        assert!(cfg.max_retries.is_none());
        assert!(cfg.temperature.is_none());
        assert!(cfg.max_tokens.is_none());
    }

    /// Regression test for https://github.com/kreuzberg-dev/kreuzberg/issues/716
    ///
    /// `LlmConfig` must implement `Default` so callers can use the struct-update
    /// syntax documented in the VLM OCR guide:
    ///
    /// ```rust
    /// use kreuzberg::core::config::LlmConfig;
    /// let cfg = LlmConfig {
    ///     model: "openai/gpt-4o-mini".to_string(),
    ///     ..Default::default()
    /// };
    /// ```
    #[test]
    fn test_llm_config_default_trait_is_satisfied() {
        let defaults = LlmConfig::default();
        assert!(defaults.model.is_empty(), "default model should be empty string");
        assert_optionals_unset(&defaults);
    }

    /// Verify the struct-update pattern from the issue compiles and sets
    /// only the explicitly provided field.
    #[test]
    fn test_llm_config_struct_update_syntax() {
        let partial = LlmConfig {
            model: "openai/gpt-4o-mini".to_string(),
            ..Default::default()
        };
        assert_eq!(partial.model, "openai/gpt-4o-mini");
        assert_optionals_unset(&partial);
    }
}
|
||||
150
crates/kreuzberg/src/core/config/merge.rs
Normal file
150
crates/kreuzberg/src/core/config/merge.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
//! JSON-level configuration merging.
|
||||
//!
|
||||
//! Provides a unified merge function for combining a base `ExtractionConfig` with
|
||||
//! JSON overrides. Used by both the CLI (`--config-json`) and MCP server to apply
|
||||
//! partial configuration overrides without losing unspecified fields.
|
||||
|
||||
use super::ExtractionConfig;
|
||||
|
||||
/// Merge extraction configuration using JSON-level field override.
|
||||
///
|
||||
/// Serializes the base config to JSON, merges each field from the override JSON
|
||||
/// (top-level only), and deserializes back. This correctly handles boolean fields
|
||||
/// explicitly set to their default values — the override always wins for any field
|
||||
/// present in `override_json`.
|
||||
///
|
||||
/// Fields **not** present in `override_json` are preserved from `base`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `Err` if the base config cannot be serialized, or if the merged JSON
|
||||
/// cannot be deserialized back into `ExtractionConfig` (e.g., wrong field types).
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use kreuzberg::ExtractionConfig;
|
||||
/// use serde_json::json;
|
||||
///
|
||||
/// let mut base = ExtractionConfig::default();
|
||||
/// base.use_cache = true;
|
||||
///
|
||||
/// let overrides = r#"{"force_ocr": true}"#;
|
||||
/// let merged = kreuzberg::core::config::merge::merge_config_json(&base, overrides).unwrap();
|
||||
/// assert!(merged.use_cache); // preserved from base
|
||||
/// assert!(merged.force_ocr); // applied from override
|
||||
/// ```
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn merge_config_json(base: &ExtractionConfig, override_json: &str) -> Result<ExtractionConfig, String> {
|
||||
let override_value: serde_json::Value =
|
||||
serde_json::from_str(override_json).map_err(|e| format!("Failed to parse override JSON: {e}"))?;
|
||||
|
||||
let mut config_json =
|
||||
serde_json::to_value(base).map_err(|e| format!("Failed to serialize base config to JSON: {e}"))?;
|
||||
|
||||
if let serde_json::Value::Object(json_obj) = override_value
|
||||
&& let Some(config_obj) = config_json.as_object_mut()
|
||||
{
|
||||
for (key, value) in json_obj {
|
||||
config_obj.insert(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
serde_json::from_value(config_json).map_err(|e| format!("Failed to deserialize merged config: {e}"))
|
||||
}
|
||||
|
||||
/// Build extraction config by optionally merging JSON overrides into a base config.
|
||||
///
|
||||
/// If `override_json` is `None`, returns a clone of `base`. Otherwise delegates
|
||||
/// to [`merge_config_json`].
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn build_config_from_json(
|
||||
base: &ExtractionConfig,
|
||||
override_json: Option<&str>,
|
||||
) -> Result<ExtractionConfig, String> {
|
||||
match override_json {
|
||||
Some(json) => merge_config_json(base, json),
|
||||
None => Ok(base.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Keys missing from the override keep the base value; keys present win.
    #[test]
    fn test_merge_preserves_unspecified_fields() {
        let base_config = ExtractionConfig {
            use_cache: false,
            enable_quality_processing: true,
            force_ocr: false,
            ..Default::default()
        };
        let overrides = r#"{"force_ocr": true}"#;

        let merged = merge_config_json(&base_config, overrides).unwrap();

        assert!(!merged.use_cache, "use_cache should be preserved from base");
        assert!(
            merged.enable_quality_processing,
            "enable_quality_processing should be preserved"
        );
        assert!(merged.force_ocr, "force_ocr should be overridden");
    }

    /// An override equal to the struct default must still replace the base value.
    #[test]
    fn test_merge_override_to_default_value() {
        let cache_disabled = ExtractionConfig {
            use_cache: false,
            ..Default::default()
        };

        let merged = merge_config_json(&cache_disabled, r#"{"use_cache": true}"#).unwrap();

        assert!(
            merged.use_cache,
            "Should use explicit override even if it matches the struct default"
        );
    }

    /// Several keys can be overridden at once without disturbing the rest.
    #[test]
    fn test_merge_multiple_fields() {
        use crate::core::config::formats::OutputFormat;

        let base_config = ExtractionConfig {
            use_cache: true,
            force_ocr: true,
            ..Default::default()
        };
        let overrides = r#"{"use_cache": false, "output_format": "markdown"}"#;

        let merged = merge_config_json(&base_config, overrides).unwrap();

        assert!(!merged.use_cache);
        assert!(merged.force_ocr, "force_ocr should be preserved");
        assert_eq!(merged.output_format, OutputFormat::Markdown);
    }

    /// A wrongly typed override surfaces as a deserialization error.
    #[test]
    fn test_merge_invalid_field_type_returns_error() {
        let defaults = ExtractionConfig::default();

        let outcome = merge_config_json(&defaults, r#"{"use_cache": "not_a_boolean"}"#);

        assert!(outcome.is_err());
        let message = outcome.unwrap_err();
        assert!(message.contains("Failed to deserialize"));
    }

    /// Without an override the base comes back as-is.
    #[test]
    fn test_build_config_from_json_none_returns_clone() {
        let cache_disabled = ExtractionConfig {
            use_cache: false,
            ..Default::default()
        };

        let built = build_config_from_json(&cache_disabled, None).unwrap();

        assert!(!built.use_cache);
    }

    /// With an override the merge is applied.
    #[test]
    fn test_build_config_from_json_some_merges() {
        let defaults = ExtractionConfig::default();
        let overrides = Some(r#"{"force_ocr": true}"#);

        let built = build_config_from_json(&defaults, overrides).unwrap();

        assert!(built.force_ocr);
    }
}
|
||||
47
crates/kreuzberg/src/core/config/mod.rs
Normal file
47
crates/kreuzberg/src/core/config/mod.rs
Normal file
@@ -0,0 +1,47 @@
|
||||
//! Configuration loading and management.
|
||||
//!
|
||||
//! This module provides utilities for loading extraction configuration from various
|
||||
//! sources (TOML, YAML, JSON) and discovering configuration files in the project hierarchy.
|
||||
|
||||
pub mod acceleration;
|
||||
pub mod concurrency;
|
||||
pub mod content_filter;
|
||||
pub mod email;
|
||||
pub mod extraction;
|
||||
pub mod formats;
|
||||
#[cfg(feature = "html")]
|
||||
pub mod html_output;
|
||||
pub mod layout;
|
||||
pub mod llm;
|
||||
pub mod merge;
|
||||
pub mod ocr;
|
||||
pub mod page;
|
||||
pub mod pdf;
|
||||
pub mod processing;
|
||||
#[cfg(feature = "tree-sitter")]
|
||||
pub mod tree_sitter;
|
||||
|
||||
// Re-export main types for backward compatibility
|
||||
pub use acceleration::{AccelerationConfig, ExecutionProviderType};
|
||||
pub use concurrency::ConcurrencyConfig;
|
||||
pub use content_filter::ContentFilterConfig;
|
||||
pub use email::EmailConfig;
|
||||
pub use extraction::{
|
||||
BatchBytesItem, BatchFileItem, ExtractionConfig, FileExtractionConfig, ImageExtractionConfig,
|
||||
LanguageDetectionConfig, TokenReductionOptions,
|
||||
};
|
||||
pub use formats::OutputFormat;
|
||||
#[cfg(feature = "html")]
|
||||
pub use html_output::{HtmlOutputConfig, HtmlTheme};
|
||||
#[cfg(feature = "layout-types")]
|
||||
pub use layout::{LayoutDetectionConfig, TableModel};
|
||||
pub use llm::{LlmConfig, StructuredExtractionConfig};
|
||||
pub use ocr::{OcrConfig, OcrPipelineConfig, OcrPipelineStage, OcrQualityThresholds};
|
||||
pub use page::PageConfig;
|
||||
#[cfg(feature = "pdf")]
|
||||
pub use pdf::{HierarchyConfig, PdfConfig};
|
||||
pub use processing::{
|
||||
ChunkSizing, ChunkerType, ChunkingConfig, EmbeddingConfig, EmbeddingModelType, PostProcessorConfig,
|
||||
};
|
||||
#[cfg(feature = "tree-sitter")]
|
||||
pub use tree_sitter::{CodeContentMode, TreeSitterConfig, TreeSitterProcessConfig};
|
||||
932
crates/kreuzberg/src/core/config/ocr.rs
Normal file
932
crates/kreuzberg/src/core/config/ocr.rs
Normal file
@@ -0,0 +1,932 @@
|
||||
//! OCR configuration.
|
||||
//!
|
||||
//! Defines OCR-specific configuration including backend selection, language settings,
|
||||
//! Tesseract-specific parameters, quality thresholds, and multi-backend pipeline config.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use super::formats::OutputFormat;
|
||||
#[cfg(test)]
|
||||
use crate::core::config_validation::validate_ocr_backend;
|
||||
#[cfg(test)]
|
||||
use crate::error::KreuzbergError;
|
||||
use crate::types::OcrElementConfig;
|
||||
|
||||
/// Quality thresholds for OCR fallback decisions and pipeline quality gating.
|
||||
///
|
||||
/// All fields default to the values that match the previous hardcoded behavior,
|
||||
/// so `OcrQualityThresholds::default()` preserves existing semantics exactly.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrQualityThresholds {
|
||||
/// Minimum total non-whitespace characters to consider text substantive.
|
||||
#[serde(default = "default_min_total_non_whitespace")]
|
||||
pub min_total_non_whitespace: usize,
|
||||
|
||||
/// Minimum non-whitespace characters per page on average.
|
||||
#[serde(default = "default_min_non_whitespace_per_page")]
|
||||
pub min_non_whitespace_per_page: f64,
|
||||
|
||||
/// Minimum character count for a word to be "meaningful".
|
||||
#[serde(default = "default_min_meaningful_word_len")]
|
||||
pub min_meaningful_word_len: usize,
|
||||
|
||||
/// Minimum count of meaningful words before text is accepted.
|
||||
#[serde(default = "default_min_meaningful_words")]
|
||||
pub min_meaningful_words: usize,
|
||||
|
||||
/// Minimum alphanumeric ratio (non-whitespace chars that are alphanumeric).
|
||||
#[serde(default = "default_min_alnum_ratio")]
|
||||
pub min_alnum_ratio: f64,
|
||||
|
||||
/// Minimum Unicode replacement characters (U+FFFD) to trigger OCR fallback.
|
||||
#[serde(default = "default_min_garbage_chars")]
|
||||
pub min_garbage_chars: usize,
|
||||
|
||||
/// Maximum fraction of short (1-2 char) words before text is considered fragmented.
|
||||
#[serde(default = "default_max_fragmented_word_ratio")]
|
||||
pub max_fragmented_word_ratio: f64,
|
||||
|
||||
/// Critical fragmentation threshold — triggers OCR regardless of meaningful words.
|
||||
/// Normal English text has ~20-30% short words. 80%+ is definitive garbage.
|
||||
#[serde(default = "default_critical_fragmented_word_ratio")]
|
||||
pub critical_fragmented_word_ratio: f64,
|
||||
|
||||
/// Minimum average word length. Below this with enough words indicates garbled extraction.
|
||||
#[serde(default = "default_min_avg_word_length")]
|
||||
pub min_avg_word_length: f64,
|
||||
|
||||
/// Minimum word count before average word length check applies.
|
||||
#[serde(default = "default_min_words_for_avg_length_check")]
|
||||
pub min_words_for_avg_length_check: usize,
|
||||
|
||||
/// Minimum consecutive word repetition ratio to detect column scrambling.
|
||||
#[serde(default = "default_min_consecutive_repeat_ratio")]
|
||||
pub min_consecutive_repeat_ratio: f64,
|
||||
|
||||
/// Minimum word count before consecutive repetition check is applied.
|
||||
#[serde(default = "default_min_words_for_repeat_check")]
|
||||
pub min_words_for_repeat_check: usize,
|
||||
|
||||
/// Minimum character count for "substantive markdown" OCR skip gate.
|
||||
#[serde(default = "default_substantive_min_chars")]
|
||||
pub substantive_min_chars: usize,
|
||||
|
||||
/// Minimum character count for "non-text content" OCR skip gate.
|
||||
#[serde(default = "default_non_text_min_chars")]
|
||||
pub non_text_min_chars: usize,
|
||||
|
||||
/// Alphanumeric+whitespace ratio threshold for skip decisions.
|
||||
#[serde(default = "default_alnum_ws_ratio_threshold")]
|
||||
pub alnum_ws_ratio_threshold: f64,
|
||||
|
||||
/// Minimum quality score (0.0-1.0) for a pipeline stage result to be accepted.
|
||||
/// If the result from a backend scores below this, try the next backend.
|
||||
#[serde(default = "default_pipeline_min_quality")]
|
||||
pub pipeline_min_quality: f64,
|
||||
}
|
||||
|
||||
impl Default for OcrQualityThresholds {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
min_total_non_whitespace: 64,
|
||||
min_non_whitespace_per_page: 32.0,
|
||||
min_meaningful_word_len: 4,
|
||||
min_meaningful_words: 3,
|
||||
min_alnum_ratio: 0.3,
|
||||
min_garbage_chars: 5,
|
||||
max_fragmented_word_ratio: 0.6,
|
||||
critical_fragmented_word_ratio: 0.80,
|
||||
min_avg_word_length: 2.0,
|
||||
min_words_for_avg_length_check: 50,
|
||||
min_consecutive_repeat_ratio: 0.08,
|
||||
min_words_for_repeat_check: 50,
|
||||
substantive_min_chars: 100,
|
||||
non_text_min_chars: 20,
|
||||
alnum_ws_ratio_threshold: 0.4,
|
||||
pipeline_min_quality: 0.5,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde default for `OcrQualityThresholds::min_total_non_whitespace`.
fn default_min_total_non_whitespace() -> usize {
    64_usize
}
|
||||
/// Serde default for `OcrQualityThresholds::min_non_whitespace_per_page`.
fn default_min_non_whitespace_per_page() -> f64 {
    32.0_f64
}
|
||||
/// Serde default for `OcrQualityThresholds::min_meaningful_word_len`.
fn default_min_meaningful_word_len() -> usize {
    4_usize
}
|
||||
/// Serde default for `OcrQualityThresholds::min_meaningful_words`.
fn default_min_meaningful_words() -> usize {
    3_usize
}
|
||||
/// Serde default for `OcrQualityThresholds::min_alnum_ratio`.
fn default_min_alnum_ratio() -> f64 {
    0.3_f64
}
|
||||
/// Serde default for `OcrQualityThresholds::min_garbage_chars`.
fn default_min_garbage_chars() -> usize {
    5_usize
}
|
||||
/// Serde default for `OcrQualityThresholds::max_fragmented_word_ratio`.
fn default_max_fragmented_word_ratio() -> f64 {
    0.6_f64
}
|
||||
/// Serde default for `OcrQualityThresholds::critical_fragmented_word_ratio`.
fn default_critical_fragmented_word_ratio() -> f64 {
    0.80_f64
}
|
||||
/// Serde default for `OcrQualityThresholds::min_avg_word_length`.
fn default_min_avg_word_length() -> f64 {
    2.0_f64
}
|
||||
/// Serde default for `OcrQualityThresholds::min_words_for_avg_length_check`.
fn default_min_words_for_avg_length_check() -> usize {
    50_usize
}
|
||||
/// Serde default for `OcrQualityThresholds::min_consecutive_repeat_ratio`.
fn default_min_consecutive_repeat_ratio() -> f64 {
    0.08_f64
}
|
||||
/// Serde default for `OcrQualityThresholds::min_words_for_repeat_check`.
fn default_min_words_for_repeat_check() -> usize {
    50_usize
}
|
||||
/// Serde default for `OcrQualityThresholds::substantive_min_chars`.
fn default_substantive_min_chars() -> usize {
    100_usize
}
|
||||
/// Serde default for `OcrQualityThresholds::non_text_min_chars`.
fn default_non_text_min_chars() -> usize {
    20_usize
}
|
||||
/// Serde default for `OcrQualityThresholds::alnum_ws_ratio_threshold`.
fn default_alnum_ws_ratio_threshold() -> f64 {
    0.4_f64
}
|
||||
/// Serde default for `OcrQualityThresholds::pipeline_min_quality`.
fn default_pipeline_min_quality() -> f64 {
    0.5_f64
}
|
||||
|
||||
/// A single backend stage in the OCR pipeline.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrPipelineStage {
|
||||
/// Backend name: "tesseract", "paddleocr", "easyocr", or a custom registered name.
|
||||
pub backend: String,
|
||||
|
||||
/// Priority weight (higher = tried first). Stages are sorted by priority descending.
|
||||
#[serde(default = "default_priority")]
|
||||
pub priority: u32,
|
||||
|
||||
/// Language override for this stage (None = use parent OcrConfig.language).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub language: Option<String>,
|
||||
|
||||
/// Tesseract-specific config override for this stage.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub tesseract_config: Option<crate::types::TesseractConfig>,
|
||||
|
||||
/// PaddleOCR-specific config for this stage.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub paddle_ocr_config: Option<serde_json::Value>,
|
||||
|
||||
/// VLM config override for this pipeline stage.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub vlm_config: Option<super::llm::LlmConfig>,
|
||||
|
||||
/// Arbitrary per-call options passed through to the backend unchanged.
|
||||
///
|
||||
/// Backends that support runtime tuning (mode switching, preprocessing
|
||||
/// flags, inference parameters, etc.) read this value and deserialize
|
||||
/// the keys they care about. Keys unknown to the backend are silently
|
||||
/// ignored, so options from different backends can coexist in the same
|
||||
/// config without conflict.
|
||||
///
|
||||
/// Example (custom backend):
|
||||
/// ```json
|
||||
/// { "mode": "fast", "enable_layout": true }
|
||||
/// ```
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub backend_options: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Serde default for `OcrPipelineStage::priority`.
fn default_priority() -> u32 {
    100_u32
}
|
||||
|
||||
/// Multi-backend OCR pipeline with quality-based fallback.
|
||||
///
|
||||
/// Backends are tried in priority order (highest first). After each backend
|
||||
/// produces output, quality is evaluated. If it meets `quality_thresholds.pipeline_min_quality`,
|
||||
/// the result is accepted. Otherwise the next backend is tried.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrPipelineConfig {
|
||||
/// Ordered list of backends to try. Sorted by priority (descending) at runtime.
|
||||
pub stages: Vec<OcrPipelineStage>,
|
||||
|
||||
/// Quality thresholds for deciding whether to accept a result or try the next backend.
|
||||
#[serde(default)]
|
||||
pub quality_thresholds: OcrQualityThresholds,
|
||||
}
|
||||
|
||||
/// OCR configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OcrConfig {
|
||||
/// Whether OCR is enabled.
|
||||
///
|
||||
/// Setting `enabled: false` is a shorthand for `disable_ocr: true` on the parent
|
||||
/// [`ExtractionConfig`](crate::core::config::ExtractionConfig). Images return
|
||||
/// metadata only; PDFs use native text extraction without OCR fallback.
|
||||
///
|
||||
/// Defaults to `true`. When `false`, all other OCR settings are ignored.
|
||||
#[serde(default = "default_ocr_enabled")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// OCR backend: tesseract, easyocr, paddleocr
|
||||
#[serde(default = "default_tesseract_backend")]
|
||||
pub backend: String,
|
||||
|
||||
/// Language code (e.g., "eng", "deu")
|
||||
#[serde(default = "default_eng")]
|
||||
pub language: String,
|
||||
|
||||
/// Tesseract-specific configuration (optional)
|
||||
#[serde(default)]
|
||||
pub tesseract_config: Option<crate::types::TesseractConfig>,
|
||||
|
||||
/// Output format for OCR results (optional, for format conversion)
|
||||
#[serde(default)]
|
||||
pub output_format: Option<OutputFormat>,
|
||||
|
||||
/// PaddleOCR-specific configuration (optional, JSON passthrough)
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub paddle_ocr_config: Option<serde_json::Value>,
|
||||
|
||||
/// Arbitrary per-call options passed through to the backend unchanged.
|
||||
///
|
||||
/// Custom OCR backends and built-in backends that support runtime tuning
|
||||
/// can read this value and deserialize the keys they care about. Keys
|
||||
/// unknown to the backend are silently ignored.
|
||||
///
|
||||
/// This is the recommended extension point for per-call parameters that
|
||||
/// are not covered by the typed fields above (e.g. mode switching,
|
||||
/// preprocessing flags, inference batch size).
|
||||
///
|
||||
/// **Scope:** when `pipeline` is `None`, this value is propagated to the
|
||||
/// primary stage of the auto-constructed pipeline. When `pipeline` is
|
||||
/// explicitly set, this field has **no effect** — the caller must set
|
||||
/// `OcrPipelineStage.backend_options` directly on the relevant stage(s)
|
||||
/// instead.
|
||||
///
|
||||
/// Example:
|
||||
/// ```json
|
||||
/// { "mode": "fast", "enable_layout": true, "timeout_ms": 5000 }
|
||||
/// ```
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub backend_options: Option<serde_json::Value>,
|
||||
|
||||
/// OCR element extraction configuration
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub element_config: Option<OcrElementConfig>,
|
||||
|
||||
/// Quality thresholds for the native-text-to-OCR fallback decision.
|
||||
/// When None, uses compiled defaults (matching previous hardcoded behavior).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub quality_thresholds: Option<OcrQualityThresholds>,
|
||||
|
||||
/// Multi-backend OCR pipeline configuration. When set, enables weighted
|
||||
/// fallback across multiple OCR backends based on output quality.
|
||||
/// When None, uses the single `backend` field (same as today).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub pipeline: Option<OcrPipelineConfig>,
|
||||
|
||||
/// Enable automatic page rotation based on orientation detection.
|
||||
///
|
||||
/// When enabled, uses Tesseract's `DetectOrientationScript()` to detect
|
||||
/// page orientation (0/90/180/270 degrees) before OCR. If the page is
|
||||
/// rotated with high confidence, the image is corrected before recognition.
|
||||
/// This is critical for handling rotated scanned documents.
|
||||
#[serde(default)]
|
||||
pub auto_rotate: bool,
|
||||
|
||||
/// VLM (Vision Language Model) OCR configuration.
|
||||
///
|
||||
/// Required when `backend` is `"vlm"`. Uses liter-llm to send page
|
||||
/// images to a vision model for text extraction.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub vlm_config: Option<super::llm::LlmConfig>,
|
||||
|
||||
/// Custom Jinja2 prompt template for VLM OCR.
|
||||
///
|
||||
/// When `None`, uses the default template. Available variables:
|
||||
/// - `{{ language }}` — The document language code (e.g., "eng", "deu").
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub vlm_prompt: Option<String>,
|
||||
|
||||
/// Hardware acceleration for ONNX Runtime models (e.g. PaddleOCR, layout detection).
|
||||
///
|
||||
/// Not user-configurable via config files — injected at runtime from
|
||||
/// `ExtractionConfig::acceleration` before each `process_image` call.
|
||||
#[serde(skip)]
|
||||
pub acceleration: Option<super::acceleration::AccelerationConfig>,
|
||||
|
||||
/// Caller-supplied Tesseract `traineddata` bytes per language code.
|
||||
///
|
||||
/// Primary use case is the WASM build, which has no filesystem and cannot
|
||||
/// download tessdata at runtime. Native builds typically rely on
|
||||
/// `TessdataManager` and ignore this field. When present, the WASM
|
||||
/// Tesseract backend prefers these bytes over its compile-time-bundled
|
||||
/// English data.
|
||||
///
|
||||
/// Skipped by serde to keep config files small — supply via the typed API
|
||||
/// at runtime.
|
||||
#[serde(skip)]
|
||||
pub tessdata_bytes: Option<std::collections::HashMap<String, Vec<u8>>>,
|
||||
}
|
||||
|
||||
impl Default for OcrConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
backend: default_tesseract_backend(),
|
||||
language: default_eng(),
|
||||
tesseract_config: None,
|
||||
output_format: None,
|
||||
paddle_ocr_config: None,
|
||||
backend_options: None,
|
||||
element_config: None,
|
||||
quality_thresholds: None,
|
||||
pipeline: None,
|
||||
auto_rotate: false,
|
||||
vlm_config: None,
|
||||
vlm_prompt: None,
|
||||
acceleration: None,
|
||||
tessdata_bytes: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OcrConfig {
|
||||
/// Validates that the configured backend is supported.
|
||||
///
|
||||
/// This method checks that the backend name is one of the supported OCR backends:
|
||||
/// - tesseract
|
||||
/// - easyocr
|
||||
/// - paddleocr
|
||||
///
|
||||
/// Typos in backend names are caught at configuration validation time, not at runtime.
|
||||
/// Also validates pipeline stage backends when a pipeline is configured.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn validate(&self) -> Result<(), KreuzbergError> {
|
||||
validate_ocr_backend(&self.backend)?;
|
||||
// When backend is "vlm", vlm_config must be present.
|
||||
crate::core::config_validation::validate_vlm_backend_config(&self.backend, self.vlm_config.as_ref())?;
|
||||
if let Some(ref pipeline) = self.pipeline {
|
||||
for stage in &pipeline.stages {
|
||||
validate_ocr_backend(&stage.backend)?;
|
||||
crate::core::config_validation::validate_vlm_backend_config(&stage.backend, stage.vlm_config.as_ref())?;
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns the effective quality thresholds, using configured values or defaults.
|
||||
#[cfg(feature = "ocr")]
|
||||
pub(crate) fn effective_thresholds(&self) -> OcrQualityThresholds {
|
||||
self.quality_thresholds.clone().unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Returns the effective pipeline config.
|
||||
///
|
||||
/// - If `pipeline` is explicitly set, returns it.
|
||||
/// - If `paddle-ocr` is compiled in and the backend is the default
|
||||
/// (tesseract), auto-constructs `[tesseract @ 100, paddleocr @ 50]`.
|
||||
/// - Otherwise returns `None` (single-backend mode).
|
||||
///
|
||||
/// Explicit non-default backend selections are honored as-is — a silent
|
||||
/// paddleocr fallback would mask errors from the chosen backend.
|
||||
#[cfg(feature = "ocr")]
|
||||
pub(crate) fn effective_pipeline(&self) -> Option<OcrPipelineConfig> {
|
||||
if self.pipeline.is_some() {
|
||||
return self.pipeline.clone();
|
||||
}
|
||||
|
||||
#[cfg(feature = "paddle-ocr")]
|
||||
{
|
||||
if self.backend != default_tesseract_backend() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let stages = vec![
|
||||
OcrPipelineStage {
|
||||
backend: self.backend.clone(),
|
||||
priority: 100,
|
||||
language: None,
|
||||
tesseract_config: self.tesseract_config.clone(),
|
||||
paddle_ocr_config: None,
|
||||
vlm_config: self.vlm_config.clone(),
|
||||
backend_options: self.backend_options.clone(),
|
||||
},
|
||||
OcrPipelineStage {
|
||||
backend: "paddleocr".to_string(),
|
||||
priority: 50,
|
||||
language: None,
|
||||
tesseract_config: None,
|
||||
paddle_ocr_config: self.paddle_ocr_config.clone(),
|
||||
vlm_config: None,
|
||||
backend_options: None,
|
||||
},
|
||||
];
|
||||
Some(OcrPipelineConfig {
|
||||
stages,
|
||||
quality_thresholds: self.effective_thresholds(),
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "paddle-ocr"))]
|
||||
{
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde default for `OcrConfig::enabled`: OCR is on unless switched off.
fn default_ocr_enabled() -> bool {
    const ENABLED_BY_DEFAULT: bool = true;
    ENABLED_BY_DEFAULT
}
|
||||
|
||||
/// Serde default for `OcrConfig::backend`; also the name `effective_pipeline`
/// compares against to recognize the default backend.
fn default_tesseract_backend() -> String {
    String::from("tesseract")
}
|
||||
|
||||
/// Serde default for `OcrConfig::language`.
fn default_eng() -> String {
    String::from("eng")
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
/// The default config selects tesseract and English with no extras set.
#[test]
fn test_ocr_config_default() {
    let defaults = OcrConfig::default();

    assert_eq!(defaults.backend, "tesseract");
    assert_eq!(defaults.language, "eng");

    assert!(defaults.tesseract_config.is_none());
    assert!(defaults.output_format.is_none());
}
|
||||
|
||||
/// Explicitly supplied backend and language are stored as given.
#[test]
fn test_ocr_config_with_tesseract() {
    let french = OcrConfig {
        backend: String::from("tesseract"),
        language: String::from("fra"),
        ..OcrConfig::default()
    };

    assert_eq!(french.backend, "tesseract");
    assert_eq!(french.language, "fra");
}
|
||||
|
||||
/// "tesseract" passes backend validation.
#[test]
fn test_validate_tesseract_backend() {
    let tesseract = OcrConfig {
        backend: String::from("tesseract"),
        ..OcrConfig::default()
    };

    assert!(tesseract.validate().is_ok());
}
|
||||
|
||||
/// "easyocr" passes backend validation.
#[test]
fn test_validate_easyocr_backend() {
    let easyocr = OcrConfig {
        backend: String::from("easyocr"),
        ..OcrConfig::default()
    };

    assert!(easyocr.validate().is_ok());
}
|
||||
|
||||
/// "paddleocr" passes backend validation.
#[test]
fn test_validate_paddleocr_backend() {
    let paddleocr = OcrConfig {
        backend: String::from("paddleocr"),
        ..OcrConfig::default()
    };

    assert!(paddleocr.validate().is_ok());
}
|
||||
|
||||
/// A misspelled backend name is rejected with an "Invalid OCR backend" message.
#[test]
fn test_validate_invalid_backend_typo() {
    let misspelled = OcrConfig {
        backend: String::from("tesseract_typo"),
        ..OcrConfig::default()
    };

    let outcome = misspelled.validate();

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Invalid OCR backend"));
}
|
||||
|
||||
/// An unrelated backend name is rejected, and the message either names the
/// problem or lists the valid options.
#[test]
fn test_validate_invalid_backend_completely_wrong() {
    let unknown = OcrConfig {
        backend: String::from("ocr_lib"),
        ..OcrConfig::default()
    };

    let outcome = unknown.validate();

    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    let names_problem = message.contains("Invalid OCR backend");
    let lists_options = message.contains("Valid options are");
    assert!(names_problem || lists_options);
}
|
||||
|
||||
/// The default configuration is itself valid.
#[test]
fn test_validate_default_backend() {
    let outcome = OcrConfig::default().validate();
    assert!(outcome.is_ok());
}
|
||||
|
||||
// ── effective_pipeline tests ──
|
||||
|
||||
/// An explicitly configured pipeline comes back with its stage untouched.
#[cfg(feature = "ocr")]
#[test]
fn test_effective_pipeline_explicit_pipeline_returned_unchanged() {
    let only_stage = OcrPipelineStage {
        backend: "easyocr".to_string(),
        priority: 200,
        language: Some("fra".to_string()),
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: None,
    };
    let config = OcrConfig {
        pipeline: Some(OcrPipelineConfig {
            stages: vec![only_stage],
            quality_thresholds: OcrQualityThresholds::default(),
        }),
        ..Default::default()
    };

    let effective = config.effective_pipeline().unwrap();

    assert_eq!(effective.stages.len(), 1);
    let stage = &effective.stages[0];
    assert_eq!(stage.backend, "easyocr");
    assert_eq!(stage.priority, 200);
    assert_eq!(stage.language.as_deref(), Some("fra"));
}
|
||||
|
||||
/// Choosing paddleocr explicitly must not trigger the automatic pipeline.
#[cfg(feature = "ocr")]
#[test]
fn test_effective_pipeline_explicit_paddleocr_no_autofallback() {
    let paddle_only = OcrConfig {
        backend: String::from("paddleocr"),
        ..OcrConfig::default()
    };

    let effective = paddle_only.effective_pipeline();

    assert!(effective.is_none());
}
|
||||
|
||||
/// Choosing easyocr explicitly must not trigger the automatic pipeline.
#[cfg(feature = "ocr")]
#[test]
fn test_effective_pipeline_explicit_easyocr_no_autofallback() {
    let easyocr_only = OcrConfig {
        backend: String::from("easyocr"),
        ..OcrConfig::default()
    };

    let effective = easyocr_only.effective_pipeline();

    assert!(effective.is_none());
}
|
||||
|
||||
/// With the default backend, the automatic tesseract → paddleocr pipeline
/// appears only when the `paddle-ocr` feature is compiled in.
#[cfg(feature = "ocr")]
#[test]
fn test_effective_pipeline_default_tesseract_backend() {
    let effective = OcrConfig::default().effective_pipeline();

    #[cfg(feature = "paddle-ocr")]
    {
        let pipeline = effective.unwrap();
        let summary: Vec<(&str, u32)> = pipeline
            .stages
            .iter()
            .map(|stage| (stage.backend.as_str(), stage.priority))
            .collect();
        assert_eq!(summary, [("tesseract", 100), ("paddleocr", 50)]);
    }

    #[cfg(not(feature = "paddle-ocr"))]
    {
        assert!(effective.is_none());
    }
}
|
||||
|
||||
/// Configured thresholds are returned as-is; without any, the defaults apply.
#[cfg(feature = "ocr")]
#[test]
fn test_effective_thresholds_custom_vs_default() {
    // Case 1: thresholds supplied by the caller.
    let with_custom = OcrConfig {
        quality_thresholds: Some(OcrQualityThresholds {
            min_total_non_whitespace: 128,
            min_meaningful_words: 10,
            ..Default::default()
        }),
        ..Default::default()
    };
    let custom = with_custom.effective_thresholds();
    assert_eq!(custom.min_total_non_whitespace, 128);
    assert_eq!(custom.min_meaningful_words, 10);

    // Case 2: nothing configured, so the defaults come back.
    let fallback = OcrConfig::default().effective_thresholds();
    assert_eq!(fallback.min_total_non_whitespace, 64);
    assert_eq!(fallback.min_meaningful_words, 3);
}
|
||||
|
||||
// ── Serde tests ──
|
||||
|
||||
/// A two-stage pipeline survives a JSON serialize/deserialize round trip.
#[test]
fn test_pipeline_config_serde_roundtrip() {
    let tesseract_stage = OcrPipelineStage {
        backend: "tesseract".to_string(),
        priority: 100,
        language: Some("eng".to_string()),
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: None,
    };
    let paddle_stage = OcrPipelineStage {
        backend: "paddleocr".to_string(),
        priority: 50,
        language: None,
        tesseract_config: None,
        paddle_ocr_config: Some(serde_json::json!({"use_gpu": false})),
        vlm_config: None,
        backend_options: None,
    };
    let original = OcrPipelineConfig {
        stages: vec![tesseract_stage, paddle_stage],
        quality_thresholds: OcrQualityThresholds::default(),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: OcrPipelineConfig = serde_json::from_str(&encoded).unwrap();

    let summary: Vec<(&str, u32)> = decoded
        .stages
        .iter()
        .map(|stage| (stage.backend.as_str(), stage.priority))
        .collect();
    assert_eq!(summary, [("tesseract", 100), ("paddleocr", 50)]);
    assert!(decoded.stages[1].paddle_ocr_config.is_some());
}
|
||||
|
||||
/// A stage needs only `backend`; everything else falls back to its default.
#[test]
fn test_pipeline_stage_deserialization_missing_optional_fields() {
    let minimal = r#"{"backend": "tesseract"}"#;

    let parsed: OcrPipelineStage = serde_json::from_str(minimal).unwrap();

    assert_eq!(parsed.backend, "tesseract");
    // Supplied by `default_priority`.
    assert_eq!(parsed.priority, 100);
    assert!(parsed.language.is_none());
    assert!(parsed.tesseract_config.is_none());
    assert!(parsed.paddle_ocr_config.is_none());
}
|
||||
|
||||
#[test]
fn test_pipeline_stage_default_priority_is_100() {
    // A stage that omits `priority` must deserialize with priority 100.
    let parsed: OcrPipelineStage = serde_json::from_str(r#"{"backend": "easyocr"}"#).unwrap();
    assert_eq!(parsed.priority, 100);
}
|
||||
|
||||
#[test]
fn test_ocr_config_deserialization_missing_optional_fields() {
    // An empty JSON object must produce a fully defaulted OcrConfig.
    let parsed: OcrConfig = serde_json::from_str(r#"{}"#).unwrap();

    assert_eq!(parsed.backend, "tesseract");
    assert_eq!(parsed.language, "eng");
    assert!(parsed.pipeline.is_none());
    assert!(parsed.quality_thresholds.is_none());
    assert!(parsed.element_config.is_none());
}
|
||||
|
||||
#[test]
fn test_quality_thresholds_deserialization_partial() {
    // Only one threshold is overridden in the input.
    let parsed: OcrQualityThresholds =
        serde_json::from_str(r#"{"min_total_non_whitespace": 256}"#).unwrap();

    assert_eq!(parsed.min_total_non_whitespace, 256);

    // The remaining fields keep their default values.
    assert_eq!(parsed.min_meaningful_words, 3);
    assert_eq!(parsed.min_garbage_chars, 5);
    let quality_delta = (parsed.pipeline_min_quality - 0.5).abs();
    assert!(quality_delta < f64::EPSILON);
}
|
||||
|
||||
// ── Validation tests ──
|
||||
|
||||
#[test]
fn test_validate_catches_invalid_pipeline_stage_backend() {
    // First stage uses a known backend; the second names one that does not exist.
    let valid_stage = OcrPipelineStage {
        backend: "tesseract".to_string(),
        priority: 100,
        language: None,
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: None,
    };
    let bogus_stage = OcrPipelineStage {
        backend: "invalid_backend".to_string(),
        priority: 50,
        language: None,
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: None,
    };
    let config = OcrConfig {
        pipeline: Some(OcrPipelineConfig {
            stages: vec![valid_stage, bogus_stage],
            quality_thresholds: OcrQualityThresholds::default(),
        }),
        ..Default::default()
    };

    let outcome = config.validate();
    assert!(outcome.is_err(), "Should catch invalid backend in pipeline stages");

    // The error must mention either the generic message or the offending backend name.
    let message = outcome.unwrap_err().to_string();
    let mentions_backend = ["Invalid OCR backend", "invalid_backend"]
        .iter()
        .any(|needle| message.contains(*needle));
    assert!(mentions_backend);
}
|
||||
|
||||
// ── backend_options tests ──
|
||||
|
||||
#[test]
fn test_ocr_config_backend_options_default_is_none() {
    // A default OcrConfig carries no backend options.
    assert!(OcrConfig::default().backend_options.is_none());
}
|
||||
|
||||
#[test]
fn test_ocr_config_backend_options_serde_roundtrip() {
    // Mixed string / float / bool payload to exercise several JSON value kinds.
    let options = serde_json::json!({"mode": "fast", "threshold": 0.8, "enable_layout": true});
    let original = OcrConfig {
        backend_options: Some(options),
        ..Default::default()
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: OcrConfig = serde_json::from_str(&encoded).unwrap();

    let restored = decoded.backend_options.unwrap();
    assert_eq!(restored["mode"], "fast");
    let threshold = restored["threshold"].as_f64().unwrap();
    assert!((threshold - 0.8).abs() < f64::EPSILON);
    assert_eq!(restored["enable_layout"], true);
}
|
||||
|
||||
#[test]
fn test_ocr_config_backend_options_omitted_when_none() {
    // Serializing a default config must not emit a `backend_options` key.
    let encoded = serde_json::to_string(&OcrConfig::default()).unwrap();
    let key_present = encoded.contains("backend_options");
    assert!(!key_present, "backend_options must be omitted when None");
}
|
||||
|
||||
#[test]
fn test_pipeline_stage_backend_options_serde_roundtrip() {
    // Stage-level backend options must survive a JSON round trip.
    let original = OcrPipelineStage {
        backend: "custom".to_string(),
        priority: 80,
        language: None,
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: Some(serde_json::json!({"batch_size": 4, "device": "cpu"})),
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: OcrPipelineStage = serde_json::from_str(&encoded).unwrap();

    let restored = decoded.backend_options.unwrap();
    assert_eq!(restored["batch_size"], 4);
    assert_eq!(restored["device"], "cpu");
}
|
||||
|
||||
#[test]
fn test_pipeline_stage_backend_options_omitted_when_none() {
    // A stage without backend options must serialize without that key.
    let bare_stage = OcrPipelineStage {
        backend: "tesseract".to_string(),
        priority: 100,
        language: None,
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: None,
    };

    let encoded = serde_json::to_string(&bare_stage).unwrap();
    let key_present = encoded.contains("backend_options");
    assert!(!key_present, "backend_options must be omitted when None");
}
|
||||
|
||||
#[cfg(all(feature = "ocr", feature = "paddle-ocr"))]
#[test]
fn test_effective_pipeline_propagates_backend_options_to_primary_stage() {
    // No explicit pipeline is set, so the effective pipeline is derived from the config.
    let config = OcrConfig {
        backend_options: Some(serde_json::json!({"mode": "fast"})),
        ..Default::default()
    };

    let derived = config
        .effective_pipeline()
        .expect("paddle-ocr feature must produce a pipeline");
    assert_eq!(derived.stages.len(), 2);

    // The primary (Tesseract) stage receives the top-level backend options.
    let primary_stage = &derived.stages[0];
    assert_eq!(primary_stage.backend, "tesseract");
    let primary_opts = primary_stage
        .backend_options
        .as_ref()
        .expect("primary stage must carry backend_options");
    assert_eq!(primary_opts["mode"], "fast");

    // The PaddleOCR stage should not inherit backend_options from the top-level config.
    let fallback_stage = &derived.stages[1];
    assert_eq!(fallback_stage.backend, "paddleocr");
    assert!(
        fallback_stage.backend_options.is_none(),
        "paddleocr stage must not inherit backend_options"
    );
}
|
||||
|
||||
#[cfg(feature = "ocr")]
#[test]
fn test_explicit_pipeline_ignores_top_level_backend_options() {
    // With a caller-supplied pipeline, OcrConfig.backend_options must NOT leak into
    // the returned stages: each stage keeps ownership of its own value.
    let only_stage = OcrPipelineStage {
        backend: "tesseract".to_string(),
        priority: 100,
        language: None,
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: None,
    };
    let config = OcrConfig {
        backend_options: Some(serde_json::json!({"mode": "fast"})),
        pipeline: Some(OcrPipelineConfig {
            stages: vec![only_stage],
            quality_thresholds: OcrQualityThresholds::default(),
        }),
        ..Default::default()
    };

    let resolved = config
        .effective_pipeline()
        .expect("explicit pipeline must be returned as-is");

    assert_eq!(resolved.stages.len(), 1);
    assert!(
        resolved.stages[0].backend_options.is_none(),
        "top-level backend_options must not be injected into an explicit pipeline stage"
    );
}
|
||||
|
||||
#[cfg(feature = "ocr")]
#[test]
fn test_stage_level_backend_options_preserved_in_explicit_pipeline() {
    // Options set on a stage of an explicit pipeline must come back untouched:
    // not cleared, and not replaced by the top-level value.
    let custom_stage = OcrPipelineStage {
        backend: "custom".to_string(),
        priority: 100,
        language: None,
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: Some(serde_json::json!({"device": "gpu", "batch": 8})),
    };
    let config = OcrConfig {
        backend_options: Some(serde_json::json!({"mode": "fast"})),
        pipeline: Some(OcrPipelineConfig {
            stages: vec![custom_stage],
            quality_thresholds: OcrQualityThresholds::default(),
        }),
        ..Default::default()
    };

    let resolved = config
        .effective_pipeline()
        .expect("explicit pipeline must be returned as-is");

    let kept_opts = resolved.stages[0]
        .backend_options
        .as_ref()
        .expect("stage-level backend_options must be preserved");
    assert_eq!(kept_opts["device"], "gpu");
    assert_eq!(kept_opts["batch"], 8);
}
|
||||
|
||||
#[test]
fn test_validate_passes_with_valid_pipeline_stages() {
    // Both stages name backends that validation is expected to accept.
    let tesseract_stage = OcrPipelineStage {
        backend: "tesseract".to_string(),
        priority: 100,
        language: None,
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: None,
    };
    let paddle_stage = OcrPipelineStage {
        backend: "paddleocr".to_string(),
        priority: 50,
        language: None,
        tesseract_config: None,
        paddle_ocr_config: None,
        vlm_config: None,
        backend_options: None,
    };
    let config = OcrConfig {
        pipeline: Some(OcrPipelineConfig {
            stages: vec![tesseract_stage, paddle_stage],
            quality_thresholds: OcrQualityThresholds::default(),
        }),
        ..Default::default()
    };

    let outcome = config.validate();
    assert!(outcome.is_ok());
}
|
||||
}
|
||||
57
crates/kreuzberg/src/core/config/page.rs
Normal file
57
crates/kreuzberg/src/core/config/page.rs
Normal file
@@ -0,0 +1,57 @@
|
||||
//! Page extraction and tracking configuration.
|
||||
//!
|
||||
//! Controls how pages are extracted, tracked, and represented in extraction results.
|
||||
//! When `None`, page tracking is disabled.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Page extraction and tracking configuration.
|
||||
///
|
||||
/// Controls how pages are extracted, tracked, and represented in the extraction results.
|
||||
/// When `None`, page tracking is disabled.
|
||||
///
|
||||
/// Page range tracking in chunk metadata (first_page/last_page) is automatically enabled
|
||||
/// when page boundaries are available and chunking is configured.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct PageConfig {
|
||||
/// Extract pages as separate array (ExtractionResult.pages)
|
||||
#[serde(default)]
|
||||
pub extract_pages: bool,
|
||||
|
||||
/// Insert page markers in main content string
|
||||
#[serde(default)]
|
||||
pub insert_page_markers: bool,
|
||||
|
||||
/// Page marker format (use {page_num} placeholder)
|
||||
/// Default: "\n\n<!-- PAGE {page_num} -->\n\n"
|
||||
#[serde(default = "default_page_marker_format")]
|
||||
pub marker_format: String,
|
||||
}
|
||||
|
||||
impl Default for PageConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
extract_pages: false,
|
||||
insert_page_markers: false,
|
||||
marker_format: "\n\n<!-- PAGE {page_num} -->\n\n".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde default for `PageConfig::marker_format`: an HTML comment containing the
/// `{page_num}` placeholder, surrounded by blank lines.
fn default_page_marker_format() -> String {
    String::from("\n\n<!-- PAGE {page_num} -->\n\n")
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The default config extracts no pages, inserts no markers, and uses the
    /// HTML-comment marker template.
    #[test]
    fn test_page_config_default() {
        let PageConfig {
            extract_pages,
            insert_page_markers,
            marker_format,
        } = PageConfig::default();

        assert!(!extract_pages);
        assert!(!insert_page_markers);
        assert_eq!(marker_format, "\n\n<!-- PAGE {page_num} -->\n\n");
    }
}
|
||||
192
crates/kreuzberg/src/core/config/pdf.rs
Normal file
192
crates/kreuzberg/src/core/config/pdf.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
//! PDF-specific configuration.
|
||||
//!
|
||||
//! Defines PDF extraction options including metadata handling, image extraction,
|
||||
//! password management, and hierarchy extraction for document structure analysis.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// PDF-specific configuration.
///
/// Every field has a serde default, so any subset of keys may be supplied when
/// deserializing.
#[cfg(feature = "pdf")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PdfConfig {
    /// Extract images from PDF. Default: `false`.
    #[serde(default)]
    pub extract_images: bool,

    /// Extract tables from PDF.
    ///
    /// When `true` (default), runs pdf_oxide's native grid detector and, if it
    /// finds nothing, falls back to the heuristic text-layer reconstruction in
    /// `pdf::oxide::table::extract_tables_heuristic`. Set to `false` to skip
    /// both passes — `tables` will then be empty in the result.
    #[serde(default = "default_true")]
    pub extract_tables: bool,

    /// List of passwords to try when opening encrypted PDFs. Default: `None`.
    #[serde(default)]
    pub passwords: Option<Vec<String>>,

    /// Extract PDF metadata. Default: `true`.
    #[serde(default = "default_true")]
    pub extract_metadata: bool,

    /// Hierarchy extraction configuration (None = hierarchy extraction disabled)
    #[serde(default)]
    pub hierarchy: Option<HierarchyConfig>,

    /// Extract PDF annotations (text notes, highlights, links, stamps).
    /// Default: false
    #[serde(default)]
    pub extract_annotations: bool,

    /// Top margin fraction (0.0–1.0) of page height to exclude headers/running heads.
    ///
    /// The field itself defaults to `None`.
    /// NOTE(review): the 0.06 (6%) value is presumably the fallback applied by the
    /// consumer when this is `None` — confirm against the extractor.
    #[serde(default)]
    pub top_margin_fraction: Option<f32>,

    /// Bottom margin fraction (0.0–1.0) of page height to exclude footers/page numbers.
    ///
    /// The field itself defaults to `None`.
    /// NOTE(review): the 0.05 (5%) value is presumably the fallback applied by the
    /// consumer when this is `None` — confirm against the extractor.
    #[serde(default)]
    pub bottom_margin_fraction: Option<f32>,

    /// Allow single-column pseudo tables in extraction results.
    ///
    /// By default, tables with fewer than 2 columns (layout-guided) or 3 columns
    /// (heuristic) are rejected. When `true`, the minimum column count is relaxed
    /// to 1, allowing single-column structured data (glossaries, itemized lists)
    /// to be emitted as tables. Other quality filters (density, sparsity, prose
    /// detection) still apply.
    #[serde(default)]
    pub allow_single_column_tables: bool,

    /// Perform OCR on inline images extracted from PDF pages and attach the
    /// recognized text to each `ExtractedImage.ocr_result`. Requires Tesseract
    /// to be available; if `ExtractionConfig.ocr` is `None` the extractor
    /// falls back to `TesseractConfig::default()`. Per-image failures degrade
    /// gracefully (the image is returned without OCR text rather than failing
    /// the whole extraction). Default: `false`.
    #[serde(default)]
    pub ocr_inline_images: bool,
}
|
||||
|
||||
/// Hierarchy extraction configuration for PDF text structure analysis.
|
||||
///
|
||||
/// Enables extraction of document hierarchy levels (H1-H6) based on font size
|
||||
/// clustering and semantic analysis. When enabled, hierarchical blocks are
|
||||
/// included in page content.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HierarchyConfig {
|
||||
/// Enable hierarchy extraction
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Number of font size clusters to use for hierarchy levels (1-7)
|
||||
///
|
||||
/// Default: 6, which provides H1-H6 heading levels with body text.
|
||||
/// Larger values create more fine-grained hierarchy levels.
|
||||
#[serde(default = "default_k_clusters")]
|
||||
pub k_clusters: usize,
|
||||
|
||||
/// Include bounding box information in hierarchy blocks
|
||||
#[serde(default = "default_true")]
|
||||
pub include_bbox: bool,
|
||||
|
||||
/// OCR coverage threshold for smart OCR triggering (0.0-1.0)
|
||||
///
|
||||
/// Determines when OCR should be triggered based on text block coverage.
|
||||
/// OCR is triggered when text blocks cover less than this fraction of the page.
|
||||
/// Default: 0.5 (trigger OCR if less than 50% of page has text)
|
||||
#[serde(default = "default_ocr_coverage_threshold")]
|
||||
pub ocr_coverage_threshold: Option<f32>,
|
||||
}
|
||||
|
||||
#[cfg(feature = "pdf")]
impl Default for PdfConfig {
    /// Table and metadata extraction are on; every other switch is off and every
    /// optional setting is `None`.
    fn default() -> Self {
        Self {
            // Enabled by default, using the same helper as the serde attributes.
            extract_tables: default_true(),
            extract_metadata: default_true(),

            // Opt-in features.
            extract_images: false,
            extract_annotations: false,
            allow_single_column_tables: false,
            ocr_inline_images: false,

            // Unset optional settings.
            passwords: None,
            hierarchy: None,
            top_margin_fraction: None,
            bottom_margin_fraction: None,
        }
    }
}
|
||||
|
||||
impl Default for HierarchyConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
k_clusters: 3,
|
||||
include_bbox: true,
|
||||
ocr_coverage_threshold: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde default helper for boolean fields that are enabled unless turned off.
fn default_true() -> bool {
    true
}
|
||||
|
||||
/// Serde default for `HierarchyConfig::k_clusters`: three font-size clusters.
fn default_k_clusters() -> usize {
    3
}
|
||||
|
||||
/// Serde default for `HierarchyConfig::ocr_coverage_threshold`: no threshold set.
fn default_ocr_coverage_threshold() -> Option<f32> {
    Option::None
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    /// Default hierarchy settings: enabled, 3 clusters, bounding boxes, no threshold.
    #[test]
    #[cfg(feature = "pdf")]
    fn test_hierarchy_config_default() {
        use super::*;
        let HierarchyConfig {
            enabled,
            k_clusters,
            include_bbox,
            ocr_coverage_threshold,
        } = HierarchyConfig::default();

        assert!(enabled);
        assert_eq!(k_clusters, 3);
        assert!(include_bbox);
        assert!(ocr_coverage_threshold.is_none());
    }

    /// Explicitly supplied field values are stored as given.
    #[test]
    #[cfg(feature = "pdf")]
    fn test_hierarchy_config_disabled() {
        use super::*;
        let disabled = HierarchyConfig {
            enabled: false,
            k_clusters: 3,
            include_bbox: false,
            ocr_coverage_threshold: Some(0.7),
        };

        assert!(!disabled.enabled);
        assert_eq!(disabled.k_clusters, 3);
        assert!(!disabled.include_bbox);
        assert_eq!(disabled.ocr_coverage_threshold, Some(0.7));
    }

    /// Custom margin fractions are stored as given; all other fields keep the
    /// values `PdfConfig::default()` provides.
    #[test]
    #[cfg(feature = "pdf")]
    fn test_pdf_config_custom_margins() {
        use super::*;
        let with_margins = PdfConfig {
            top_margin_fraction: Some(0.10),
            bottom_margin_fraction: Some(0.08),
            ..PdfConfig::default()
        };

        assert_eq!(with_margins.top_margin_fraction, Some(0.10));
        assert_eq!(with_margins.bottom_margin_fraction, Some(0.08));
    }
}
|
||||
838
crates/kreuzberg/src/core/config/processing.rs
Normal file
838
crates/kreuzberg/src/core/config/processing.rs
Normal file
@@ -0,0 +1,838 @@
|
||||
//! Post-processing and chunking configuration.
|
||||
//!
|
||||
//! Defines configuration for post-processing pipelines, text chunking,
|
||||
//! and embedding generation.
|
||||
|
||||
use ahash::AHashSet;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Type of text chunker to use.
|
||||
///
|
||||
/// # Variants
|
||||
///
|
||||
/// * `Text` - Generic text splitter, splits on whitespace and punctuation
|
||||
/// * `Markdown` - Markdown-aware splitter, preserves formatting and structure
|
||||
/// * `Yaml` - YAML-aware splitter, creates one chunk per top-level key
|
||||
/// * `Semantic` - Topic-aware chunker. With an `EmbeddingConfig`, splits at
|
||||
/// embedding-based topic shifts tuned by `topic_threshold` (default 0.75,
|
||||
/// lower = more splits). Without an embedding, falls back to a
|
||||
/// structural-boundary heuristic (ALL-CAPS headers, numbered sections,
|
||||
/// blank-line paragraphs) and merges groups into chunks capped at
|
||||
/// `max_characters` (default 1000). `topic_threshold` has no effect in the
|
||||
/// fallback path. For best results, pair with an embedding model.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum ChunkerType {
|
||||
#[default]
|
||||
Text,
|
||||
Markdown,
|
||||
Yaml,
|
||||
Semantic,
|
||||
}
|
||||
|
||||
/// How chunk size is measured.
|
||||
///
|
||||
/// Defaults to `Characters` (Unicode character count). When using token-based sizing,
|
||||
/// chunks are sized by token count according to the specified tokenizer.
|
||||
///
|
||||
/// Token-based sizing uses HuggingFace tokenizers loaded at runtime. Any tokenizer
|
||||
/// available on HuggingFace Hub can be used, including OpenAI-compatible tokenizers
|
||||
/// (e.g., `Xenova/gpt-4o`, `Xenova/cl100k_base`).
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||
#[serde(tag = "type", rename_all = "snake_case")]
|
||||
pub enum ChunkSizing {
|
||||
/// Size measured in Unicode characters (default).
|
||||
#[default]
|
||||
Characters,
|
||||
/// Size measured in tokens from a HuggingFace tokenizer.
|
||||
#[cfg(feature = "chunking-tokenizers")]
|
||||
Tokenizer {
|
||||
/// HuggingFace model ID or path, e.g. "Xenova/gpt-4o", "bert-base-uncased".
|
||||
model: String,
|
||||
/// Optional cache directory override for tokenizer files.
|
||||
/// Defaults to hf-hub's standard cache (`~/.cache/huggingface/`).
|
||||
/// Can also be set via `KREUZBERG_TOKENIZER_CACHE_DIR` environment variable.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
cache_dir: Option<std::path::PathBuf>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Post-processor configuration.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PostProcessorConfig {
|
||||
/// Enable post-processors
|
||||
#[serde(default = "default_true")]
|
||||
pub enabled: bool,
|
||||
|
||||
/// Whitelist of processor names to run (None = all enabled)
|
||||
#[serde(default)]
|
||||
pub enabled_processors: Option<Vec<String>>,
|
||||
|
||||
/// Blacklist of processor names to skip (None = none disabled)
|
||||
#[serde(default)]
|
||||
pub disabled_processors: Option<Vec<String>>,
|
||||
|
||||
/// Pre-computed AHashSet for O(1) enabled processor lookup
|
||||
#[serde(skip)]
|
||||
pub enabled_set: Option<AHashSet<String>>,
|
||||
|
||||
/// Pre-computed AHashSet for O(1) disabled processor lookup
|
||||
#[serde(skip)]
|
||||
pub disabled_set: Option<AHashSet<String>>,
|
||||
}
|
||||
|
||||
impl PostProcessorConfig {
|
||||
/// Pre-compute HashSets for O(1) processor name lookups.
|
||||
///
|
||||
/// This method converts the enabled/disabled processor Vec to HashSet
|
||||
/// for constant-time lookups in the pipeline.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn build_lookup_sets(&mut self) {
|
||||
if let Some(ref enabled) = self.enabled_processors {
|
||||
self.enabled_set = Some(enabled.iter().cloned().collect());
|
||||
}
|
||||
if let Some(ref disabled) = self.disabled_processors {
|
||||
self.disabled_set = Some(disabled.iter().cloned().collect());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for PostProcessorConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
enabled_processors: None,
|
||||
disabled_processors: None,
|
||||
enabled_set: None,
|
||||
disabled_set: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Chunking configuration.
|
||||
///
|
||||
/// Configures text chunking for document content, including chunk size,
|
||||
/// overlap, trimming behavior, and optional embeddings.
|
||||
///
|
||||
/// Use `..Default::default()` when constructing to allow for future field additions:
|
||||
/// ```rust
|
||||
/// # use kreuzberg::ChunkingConfig;
|
||||
/// let config = ChunkingConfig {
|
||||
/// max_characters: 500,
|
||||
/// ..Default::default()
|
||||
/// };
|
||||
/// ```
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChunkingConfig {
|
||||
/// Maximum size per chunk (in units determined by `sizing`).
|
||||
///
|
||||
/// When `sizing` is `Characters` (default), this is the max character count.
|
||||
/// When using token-based sizing, this is the max token count.
|
||||
///
|
||||
/// Default: 1000
|
||||
#[serde(default = "default_chunk_size", rename = "max_chars", alias = "max_characters")]
|
||||
pub max_characters: usize,
|
||||
|
||||
/// Overlap between chunks (in units determined by `sizing`).
|
||||
///
|
||||
/// Default: 200
|
||||
#[serde(default = "default_chunk_overlap", rename = "max_overlap", alias = "overlap")]
|
||||
pub overlap: usize,
|
||||
|
||||
/// Whether to trim whitespace from chunk boundaries.
|
||||
///
|
||||
/// Default: true
|
||||
#[serde(default = "default_trim")]
|
||||
pub trim: bool,
|
||||
|
||||
/// Type of chunker to use (Text or Markdown).
|
||||
///
|
||||
/// Default: Text
|
||||
#[serde(default = "default_chunker_type")]
|
||||
pub chunker_type: ChunkerType,
|
||||
|
||||
/// Optional embedding configuration for chunk embeddings.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub embedding: Option<EmbeddingConfig>,
|
||||
|
||||
/// Use a preset configuration (overrides individual settings if provided).
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub preset: Option<String>,
|
||||
|
||||
/// How to measure chunk size.
|
||||
///
|
||||
/// Default: `Characters` (Unicode character count).
|
||||
/// Enable `chunking-tiktoken` or `chunking-tokenizers` features for token-based sizing.
|
||||
#[serde(default, deserialize_with = "deserialize_null_default")]
|
||||
pub sizing: ChunkSizing,
|
||||
|
||||
/// When `true` and `chunker_type` is `Markdown`, prepend the heading hierarchy
|
||||
/// path (e.g. `"# Title > ## Section\n\n"`) to each chunk's content string.
|
||||
///
|
||||
/// This is useful for RAG pipelines where each chunk needs self-contained
|
||||
/// context about its position in the document structure.
|
||||
///
|
||||
/// Default: `false`
|
||||
#[serde(default)]
|
||||
pub prepend_heading_context: bool,
|
||||
|
||||
/// Optional cosine similarity threshold for semantic topic boundary detection.
|
||||
///
|
||||
/// Only used when `chunker_type` is `Semantic` and an `EmbeddingConfig` is
|
||||
/// provided. You almost never need to set this. When omitted, defaults to
|
||||
/// `0.75` which works well for most documents. Lower values detect more
|
||||
/// topic boundaries (more, smaller chunks); higher values detect fewer.
|
||||
/// Range: `0.0..=1.0`.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub topic_threshold: Option<f32>,
|
||||
}
|
||||
|
||||
impl ChunkingConfig {
|
||||
/// Set the cosine similarity threshold for semantic topic boundary detection.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics if `threshold` is outside `[0.0, 1.0]`.
|
||||
#[cfg(test)]
|
||||
pub(crate) fn with_topic_threshold(mut self, threshold: f32) -> Self {
|
||||
assert!(
|
||||
(0.0..=1.0).contains(&threshold),
|
||||
"topic_threshold must be in [0.0, 1.0], got {threshold}"
|
||||
);
|
||||
self.topic_threshold = Some(threshold);
|
||||
self
|
||||
}
|
||||
|
||||
/// Resolve a preset name into concrete chunking and embedding configuration.
|
||||
///
|
||||
/// When `preset` is set (e.g., `"balanced"`), this overrides `max_characters` and
|
||||
/// `overlap` from the preset definition, and configures the embedding model if
|
||||
/// no embedding config was explicitly provided.
|
||||
///
|
||||
/// If the preset name is not recognized, a warning is logged and the config
|
||||
/// is returned unchanged.
|
||||
///
|
||||
/// Requires the `embeddings` feature. Without it, this is a no-op that returns
|
||||
/// the config unchanged.
|
||||
#[cfg(feature = "embeddings")]
|
||||
pub(crate) fn resolve_preset(&self) -> Self {
|
||||
let preset_name = match &self.preset {
|
||||
Some(name) => name,
|
||||
None => return self.clone(),
|
||||
};
|
||||
|
||||
let preset = match crate::embeddings::get_preset(preset_name) {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
tracing::warn!(
|
||||
"Unknown chunking preset '{}', using manual config. Available: {:?}",
|
||||
preset_name,
|
||||
crate::embeddings::list_presets()
|
||||
);
|
||||
return self.clone();
|
||||
}
|
||||
};
|
||||
|
||||
// Preserve the caller's embedding choice, including None.
|
||||
// Presets configure chunking parameters only; users must explicitly
|
||||
// provide an EmbeddingConfig to opt into embedding generation.
|
||||
let embedding = self.embedding.clone();
|
||||
|
||||
Self {
|
||||
max_characters: preset.chunk_size,
|
||||
overlap: preset.overlap,
|
||||
embedding,
|
||||
// Preserve caller's other settings
|
||||
trim: self.trim,
|
||||
chunker_type: self.chunker_type,
|
||||
preset: self.preset.clone(),
|
||||
sizing: self.sizing.clone(),
|
||||
prepend_heading_context: self.prepend_heading_context,
|
||||
topic_threshold: self.topic_threshold,
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve a preset name (no-op without the `embeddings` feature).
|
||||
#[cfg(all(feature = "chunking", not(feature = "embeddings")))]
|
||||
pub(crate) fn resolve_preset(&self) -> Self {
|
||||
if self.preset.is_some() {
|
||||
tracing::warn!("Chunking presets require the 'embeddings' feature");
|
||||
}
|
||||
self.clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for ChunkingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
max_characters: 1000,
|
||||
overlap: 200,
|
||||
trim: true,
|
||||
chunker_type: ChunkerType::Text,
|
||||
embedding: None,
|
||||
preset: None,
|
||||
sizing: ChunkSizing::default(),
|
||||
prepend_heading_context: false,
|
||||
topic_threshold: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Embedding configuration for text chunks.
|
||||
///
|
||||
/// Configures embedding generation using ONNX models via the vendored embedding engine.
|
||||
/// Requires the `embeddings` feature to be enabled.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct EmbeddingConfig {
|
||||
/// The embedding model to use (defaults to "balanced" preset if not specified)
|
||||
#[serde(default = "default_model", deserialize_with = "deserialize_null_model")]
|
||||
pub model: EmbeddingModelType,
|
||||
|
||||
/// Whether to normalize embedding vectors (recommended for cosine similarity)
|
||||
#[serde(default = "default_normalize")]
|
||||
pub normalize: bool,
|
||||
|
||||
/// Batch size for embedding generation
|
||||
#[serde(default = "default_batch_size")]
|
||||
pub batch_size: usize,
|
||||
|
||||
/// Show model download progress
|
||||
#[serde(default)]
|
||||
pub show_download_progress: bool,
|
||||
|
||||
/// Custom cache directory for model files
|
||||
///
|
||||
/// Defaults to `~/.cache/kreuzberg/embeddings/` if not specified.
|
||||
/// Allows full customization of model download location.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cache_dir: Option<PathBuf>,
|
||||
|
||||
/// Hardware acceleration for the embedding ONNX model.
|
||||
///
|
||||
/// When set, controls which execution provider (CPU, CUDA, CoreML, TensorRT)
|
||||
/// is used for inference. Defaults to `None` (auto-select per platform).
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub acceleration: Option<super::acceleration::AccelerationConfig>,
|
||||
|
||||
/// Maximum wall-clock duration (in seconds) for a single `embed()` call when
|
||||
/// using [`EmbeddingModelType::Plugin`].
|
||||
///
|
||||
/// Applies only to the in-process plugin path — protects against hung
|
||||
/// host-language backends (e.g. a Python callback deadlocked on the GIL,
|
||||
/// a model stuck on CUDA OOM retries, etc.). On timeout, the dispatcher
|
||||
/// returns [`crate::KreuzbergError::Plugin`] instead of blocking forever.
|
||||
///
|
||||
/// `None` disables the timeout. The default (60 seconds) is conservative
|
||||
/// for common in-process inference; increase for large batches on slow
|
||||
/// hardware.
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
pub max_embed_duration_secs: Option<u64>,
|
||||
}
|
||||
|
||||
impl Default for EmbeddingConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
model: EmbeddingModelType::Preset {
|
||||
name: "balanced".to_string(),
|
||||
},
|
||||
normalize: true,
|
||||
batch_size: 32,
|
||||
show_download_progress: false,
|
||||
cache_dir: None,
|
||||
acceleration: None,
|
||||
max_embed_duration_secs: Some(60),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Embedding model types supported by Kreuzberg.
///
/// Serialized as an internally-tagged object: the `type` key carries the
/// snake_case variant name (`"preset"`, `"custom"`, `"llm"`, `"plugin"`) and the
/// variant's own fields sit next to it, e.g. `{"type": "preset", "name": "fast"}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum EmbeddingModelType {
    /// Use a preset model configuration (recommended)
    Preset {
        /// Preset name, e.g. `"balanced"` (the default) or `"fast"`.
        name: String,
    },

    /// Use a custom ONNX model from HuggingFace
    Custom {
        /// HuggingFace model identifier, e.g. `"sentence-transformers/all-MiniLM-L6-v2"`.
        model_id: String,
        /// Dimension count declared for the model — presumably the length of each
        /// output vector; confirm against the model loader.
        dimensions: usize,
    },

    /// Provider-hosted embedding model via liter-llm.
    ///
    /// Uses the model specified in the nested `LlmConfig` (e.g.,
    /// `"openai/text-embedding-3-small"`).
    Llm {
        /// Provider settings; its `model` field names the embedding model to call.
        llm: super::llm::LlmConfig,
    },

    /// In-process embedding backend registered via the plugin system.
    ///
    /// The caller registers an [`EmbeddingBackend`](crate::plugins::EmbeddingBackend) once
    /// (e.g. a wrapper around an already-loaded `llama-cpp-python`, `sentence-transformers`,
    /// or tuned ONNX model), then references it by name in config. Kreuzberg calls back
    /// into the registered backend during chunking and standalone embed requests —
    /// no HuggingFace download, no ONNX Runtime requirement, no HTTP sidecar.
    ///
    /// When this variant is selected, only the following [`EmbeddingConfig`] fields
    /// apply: `normalize` (post-call L2 normalization) and `max_embed_duration_secs`
    /// (dispatcher timeout). Model-loading fields (`batch_size`, `cache_dir`,
    /// `show_download_progress`, `acceleration`) are ignored — the host owns the
    /// model lifecycle.
    ///
    /// Semantic chunking falls back to [`ChunkingConfig::max_characters`] when this variant
    /// is used, since there is no preset to look a chunk-size ceiling up against — size your
    /// context window via `max_characters` directly.
    ///
    /// See [`crate::plugins::register_embedding_backend`].
    Plugin {
        /// Name under which the backend was registered.
        name: String,
    },
}
|
||||
|
||||
impl Default for EmbeddingModelType {
|
||||
/// Returns the "balanced" preset as the default model.
|
||||
///
|
||||
/// Previously returned `Preset { name: "" }` (empty string) which caused
|
||||
/// "Unknown embedding preset: " errors in every language binding that calls
|
||||
/// `EmbeddingModelType::default()` — including generated bindings that
|
||||
/// use struct-level `#[serde(default)]` instead of `default_model()`.
|
||||
/// All defaults across the codebase converge on "balanced".
|
||||
fn default() -> Self {
|
||||
Self::Preset {
|
||||
name: "balanced".to_string(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Constant helper that always yields `true`.
///
/// Presumably referenced from `#[serde(default = "default_true")]` attributes
/// elsewhere in this file — verify before removing.
fn default_true() -> bool {
    true
}
|
||||
|
||||
/// Default chunk size: 1000.
fn default_chunk_size() -> usize {
    1_000
}
|
||||
|
||||
/// Deserialize a value that may be explicitly `null` into its `Default` value.
|
||||
///
|
||||
/// Internally-tagged serde enums (e.g. `#[serde(tag = "type")]`) reject `null`
|
||||
/// even when the containing field has `#[serde(default)]`, because that attribute
|
||||
/// only covers the *missing* case. Polyglot bindings frequently emit explicit
|
||||
/// `"field": null` from zero-valued mirror structs, so this helper accepts either
|
||||
/// `null` or a present value and falls back to `T::default()` for null.
|
||||
fn deserialize_null_default<'de, D, T>(deserializer: D) -> Result<T, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
T: Default + serde::Deserialize<'de>,
|
||||
{
|
||||
let opt = Option::<T>::deserialize(deserializer)?;
|
||||
Ok(opt.unwrap_or_default())
|
||||
}
|
||||
|
||||
/// Default chunk overlap: 200.
fn default_chunk_overlap() -> usize {
    200_usize
}
|
||||
|
||||
/// Default for the `trim` setting: enabled.
fn default_trim() -> bool {
    true
}
|
||||
|
||||
fn default_chunker_type() -> ChunkerType {
|
||||
ChunkerType::Text
|
||||
}
|
||||
|
||||
/// Serde default for `EmbeddingConfig::normalize`: normalization is on.
fn default_normalize() -> bool {
    true
}
|
||||
|
||||
/// Serde default for `EmbeddingConfig::batch_size`: 32.
fn default_batch_size() -> usize {
    32_usize
}
|
||||
|
||||
fn default_model() -> EmbeddingModelType {
|
||||
EmbeddingModelType::Preset {
|
||||
name: "balanced".to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// `deserialize_with` companion for `EmbeddingModelType` fields that may be
|
||||
/// explicitly `null` in polyglot binding payloads. Treats null as the configured
|
||||
/// `default_model()` (the "balanced" preset) rather than the trait `Default` impl
|
||||
/// (which is an empty-name placeholder unsuitable for live use).
|
||||
fn deserialize_null_model<'de, D>(deserializer: D) -> Result<EmbeddingModelType, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let opt = Option::<EmbeddingModelType>::deserialize(deserializer)?;
|
||||
Ok(opt.unwrap_or_else(default_model))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `PostProcessorConfig` is enabled and has no explicit allow/deny lists.
    #[test]
    fn test_postprocessor_config_default() {
        let defaults = PostProcessorConfig::default();
        assert!(defaults.enabled);
        assert!(defaults.enabled_processors.is_none());
        assert!(defaults.disabled_processors.is_none());
    }

    /// `build_lookup_sets` populates both lookup sets from the configured lists.
    #[test]
    fn test_postprocessor_config_build_lookup_sets() {
        let mut config = PostProcessorConfig {
            enabled: true,
            enabled_processors: Some(vec![String::from("a"), String::from("b")]),
            disabled_processors: Some(vec![String::from("c")]),
            enabled_set: None,
            disabled_set: None,
        };

        config.build_lookup_sets();

        assert!(config.enabled_set.is_some());
        assert!(config.disabled_set.is_some());
        assert!(config.enabled_set.unwrap().contains("a"));
        assert!(config.disabled_set.unwrap().contains("c"));
    }

    /// Pins the documented defaults of `ChunkingConfig`.
    #[test]
    fn test_chunking_config_defaults() {
        let defaults = ChunkingConfig::default();
        assert_eq!(defaults.max_characters, 1000);
        assert_eq!(defaults.overlap, 200);
        assert!(defaults.trim);
        assert_eq!(defaults.chunker_type, ChunkerType::Text);
        assert!(matches!(defaults.sizing, ChunkSizing::Characters));
    }

    /// Pins the documented defaults of `EmbeddingConfig`.
    #[test]
    fn test_embedding_config_default() {
        let defaults = EmbeddingConfig::default();
        assert!(defaults.normalize);
        assert_eq!(defaults.batch_size, 32);
        assert!(defaults.cache_dir.is_none());
    }

    /// `EmbeddingModelType::default()` must be the "balanced" preset.
    ///
    /// Bindings that rely on struct-level `#[serde(default)]` resolve an absent
    /// `model` field through this impl. An empty preset name used to trigger
    /// "Unknown embedding preset: " panics in `get_preset()`, so the default has
    /// to name a real preset.
    #[test]
    fn test_embedding_model_type_default_is_balanced() {
        match EmbeddingModelType::default() {
            EmbeddingModelType::Preset { name } => {
                assert_eq!(name, "balanced", "Default model should be the balanced preset");
            }
            other => panic!("Expected Preset variant, got {:?}", other),
        }
    }

    /// `Preset` serializes in the internally-tagged form documented for the API:
    /// `{"type": "preset", "name": "fast"}`, NOT `{"preset": {"name": "fast"}}`.
    #[test]
    fn test_embedding_model_type_preset_serialization() {
        let preset = EmbeddingModelType::Preset {
            name: String::from("fast"),
        };
        let serialized = serde_json::to_string(&preset).unwrap();

        // The "type" discriminator and the variant field sit side by side.
        assert!(
            serialized.contains(r#""type":"preset""#),
            "Should contain type:preset field"
        );
        assert!(serialized.contains(r#""name":"fast""#), "Should contain name:fast field");

        // The variant name must not appear as a wrapping key.
        assert!(
            !serialized.contains(r#"{"preset":"#),
            "Should NOT use adjacently-tagged format"
        );
    }

    /// `Preset` deserializes from the documented API payload
    /// `{"type": "preset", "name": "fast"}`.
    #[test]
    fn test_embedding_model_type_preset_deserialization() {
        // The format users are told to send.
        let payload = r#"{"type": "preset", "name": "fast"}"#;
        let parsed = serde_json::from_str::<EmbeddingModelType>(payload).unwrap();

        let EmbeddingModelType::Preset { name } = parsed else {
            panic!("Expected Preset variant")
        };
        assert_eq!(name, "fast");
    }

    /// The format from the old documentation must be rejected, not silently accepted.
    #[test]
    fn test_embedding_model_type_rejects_wrong_format() {
        // The WRONG shape that the old documentation showed.
        let wrong_json = r#"{"preset": {"name": "fast"}}"#;
        let outcome = serde_json::from_str::<EmbeddingModelType>(wrong_json);

        // Parsing has to fail for this shape.
        assert!(outcome.is_err(), "Should reject adjacently-tagged format");
    }

    /// An `EmbeddingConfig` survives a JSON serialize/deserialize round trip.
    #[test]
    fn test_embedding_config_roundtrip() {
        let original = EmbeddingConfig {
            model: EmbeddingModelType::Preset {
                name: String::from("balanced"),
            },
            normalize: true,
            batch_size: 64,
            show_download_progress: false,
            cache_dir: None,
            acceleration: None,
            max_embed_duration_secs: Some(60),
        };

        let serialized = serde_json::to_string(&original).unwrap();
        let restored: EmbeddingConfig = serde_json::from_str(&serialized).unwrap();

        assert!(restored.normalize);
        assert_eq!(restored.batch_size, 64);
        let EmbeddingModelType::Preset { name } = restored.model else {
            panic!("Expected Preset variant")
        };
        assert_eq!(name, "balanced");
    }

    /// `Custom` serializes with its tag and both fields.
    #[test]
    fn test_embedding_model_type_custom_serialization() {
        let custom = EmbeddingModelType::Custom {
            model_id: String::from("sentence-transformers/all-MiniLM-L6-v2"),
            dimensions: 384,
        };
        let serialized = serde_json::to_string(&custom).unwrap();

        assert!(
            serialized.contains(r#""type":"custom""#),
            "Should contain type:custom field"
        );
        assert!(serialized.contains(r#""model_id":"#), "Should contain model_id field");
        assert!(
            serialized.contains(r#""dimensions":384"#),
            "Should contain dimensions field"
        );
    }

    /// The "balanced" preset sets chunking parameters and leaves `embedding` alone.
    #[cfg(feature = "embeddings")]
    #[test]
    fn test_resolve_preset_balanced() {
        let resolved = ChunkingConfig {
            preset: Some(String::from("balanced")),
            ..Default::default()
        }
        .resolve_preset();

        assert_eq!(resolved.max_characters, 1024);
        assert_eq!(resolved.overlap, 100);
        // Preset configures chunking parameters only; embedding stays None unless
        // the caller explicitly provided one (#797).
        assert!(resolved.embedding.is_none());
    }

    /// A preset overrides chunk sizing but keeps a caller-supplied embedding config.
    #[cfg(feature = "embeddings")]
    #[test]
    fn test_resolve_preset_preserves_explicit_embedding() {
        let explicit_embedding = EmbeddingConfig {
            model: EmbeddingModelType::Custom {
                model_id: String::from("custom/model"),
                dimensions: 512,
            },
            batch_size: 64,
            ..Default::default()
        };
        let resolved = ChunkingConfig {
            preset: Some(String::from("fast")),
            embedding: Some(explicit_embedding),
            ..Default::default()
        }
        .resolve_preset();

        assert_eq!(resolved.max_characters, 512);
        assert_eq!(resolved.overlap, 50);
        // Explicit embedding config preserved
        let embedding = resolved.embedding.unwrap();
        let EmbeddingModelType::Custom { model_id, .. } = embedding.model else {
            panic!("Expected Custom model type to be preserved")
        };
        assert_eq!(model_id, "custom/model");
    }

    /// Without a preset, `resolve_preset` leaves the config as it was.
    #[cfg(any(feature = "embeddings", feature = "chunking"))]
    #[test]
    fn test_resolve_preset_no_preset_returns_unchanged() {
        let resolved = ChunkingConfig {
            max_characters: 500,
            overlap: 50,
            ..Default::default()
        }
        .resolve_preset();

        assert_eq!(resolved.max_characters, 500);
        assert_eq!(resolved.overlap, 50);
        assert!(resolved.embedding.is_none());
    }

    /// An unrecognized preset name does not alter `max_characters`.
    #[cfg(any(feature = "embeddings", feature = "chunking"))]
    #[test]
    fn test_resolve_preset_unknown_name_returns_unchanged() {
        let resolved = ChunkingConfig {
            max_characters: 500,
            preset: Some(String::from("nonexistent")),
            ..Default::default()
        }
        .resolve_preset();

        assert_eq!(resolved.max_characters, 500);
    }

    /// The `Llm` variant round-trips through JSON with its tag and model name.
    #[test]
    fn test_embedding_model_type_llm_roundtrip() {
        let llm_model = EmbeddingModelType::Llm {
            llm: crate::core::config::llm::LlmConfig {
                model: String::from("openai/text-embedding-3-small"),
                api_key: None,
                base_url: None,
                timeout_secs: None,
                max_retries: None,
                temperature: None,
                max_tokens: None,
            },
        };
        let serialized = serde_json::to_string(&llm_model).unwrap();
        assert!(serialized.contains("\"type\":\"llm\""));
        assert!(serialized.contains("openai/text-embedding-3-small"));

        let restored: EmbeddingModelType = serde_json::from_str(&serialized).unwrap();
        let EmbeddingModelType::Llm { llm } = restored else {
            panic!("Expected Llm variant")
        };
        assert_eq!(llm.model, "openai/text-embedding-3-small");
    }

    /// A threshold above 1.0 is rejected with a panic.
    #[test]
    #[should_panic(expected = "topic_threshold must be in [0.0, 1.0]")]
    fn test_with_topic_threshold_panics_above_one() {
        ChunkingConfig::default().with_topic_threshold(1.1);
    }

    /// A negative threshold is rejected with a panic.
    #[test]
    #[should_panic(expected = "topic_threshold must be in [0.0, 1.0]")]
    fn test_with_topic_threshold_panics_below_zero() {
        ChunkingConfig::default().with_topic_threshold(-0.1);
    }

    /// Both ends of the closed interval are accepted and stored.
    #[test]
    fn test_with_topic_threshold_accepts_boundary_values() {
        let at_lower_bound = ChunkingConfig::default().with_topic_threshold(0.0);
        assert_eq!(at_lower_bound.topic_threshold, Some(0.0));

        let at_upper_bound = ChunkingConfig::default().with_topic_threshold(1.0);
        assert_eq!(at_upper_bound.topic_threshold, Some(1.0));
    }

    /// `Custom` deserializes from its tagged JSON form.
    #[test]
    fn test_embedding_model_type_custom_deserialization() {
        let payload = r#"{"type": "custom", "model_id": "test/model", "dimensions": 512}"#;
        let parsed = serde_json::from_str::<EmbeddingModelType>(payload).unwrap();

        let EmbeddingModelType::Custom { model_id, dimensions } = parsed else {
            panic!("Expected Custom variant")
        };
        assert_eq!(model_id, "test/model");
        assert_eq!(dimensions, 512);
    }

    /// The `Plugin` variant round-trips through JSON with its tag and name.
    #[test]
    fn test_embedding_model_type_plugin_roundtrip() {
        let plugin = EmbeddingModelType::Plugin {
            name: String::from("lilbee-llamacpp"),
        };
        let serialized = serde_json::to_string(&plugin).unwrap();
        assert!(serialized.contains("\"type\":\"plugin\""));
        assert!(serialized.contains("lilbee-llamacpp"));

        let restored: EmbeddingModelType = serde_json::from_str(&serialized).unwrap();
        let EmbeddingModelType::Plugin { name } = restored else {
            panic!("Expected Plugin variant")
        };
        assert_eq!(name, "lilbee-llamacpp");
    }

    /// `Plugin` deserializes from its tagged JSON form.
    #[test]
    fn test_embedding_model_type_plugin_deserialization() {
        let payload = r#"{"type": "plugin", "name": "my-embedder"}"#;
        let parsed = serde_json::from_str::<EmbeddingModelType>(payload).unwrap();

        let EmbeddingModelType::Plugin { name } = parsed else {
            panic!("Expected Plugin variant")
        };
        assert_eq!(name, "my-embedder");
    }

    // --- Issue #797 regression tests ---

    /// Preset with no explicit embedding: embedding must remain None.
    ///
    /// Before the fix, `resolve_preset()` silently injected an `EmbeddingConfig`
    /// whenever a preset was configured, so every chunk came back with an
    /// unexpected `.embedding` field populated.
    #[cfg(feature = "embeddings")]
    #[test]
    fn test_resolve_preset_does_not_inject_embedding_when_none() {
        let resolved = ChunkingConfig {
            preset: Some(String::from("multilingual")),
            embedding: None,
            ..Default::default()
        }
        .resolve_preset();

        assert!(
            resolved.embedding.is_none(),
            "preset alone must not inject an EmbeddingConfig (#797)"
        );
    }

    /// Preset with an explicit embedding: the embedding must be preserved unchanged.
    #[cfg(feature = "embeddings")]
    #[test]
    fn test_resolve_preset_preserves_explicit_embedding_config() {
        let explicit = EmbeddingConfig {
            model: EmbeddingModelType::Custom {
                model_id: String::from("my-org/model"),
                dimensions: 768,
            },
            batch_size: 16,
            ..Default::default()
        };
        let resolved = ChunkingConfig {
            preset: Some(String::from("multilingual")),
            embedding: Some(explicit),
            ..Default::default()
        }
        .resolve_preset();

        let emb = resolved
            .embedding
            .expect("explicit embedding must survive resolve_preset");
        assert_eq!(emb.batch_size, 16);
        match emb.model {
            EmbeddingModelType::Custom { model_id, dimensions } => {
                assert_eq!(model_id, "my-org/model");
                assert_eq!(dimensions, 768);
            }
            other => panic!("expected Custom model type, got {other:?}"),
        }
    }

    /// No preset, no embedding: embedding must stay None (regression guard).
    #[cfg(any(feature = "embeddings", feature = "chunking"))]
    #[test]
    fn test_resolve_preset_no_preset_no_embedding_stays_none() {
        let resolved = ChunkingConfig {
            preset: None,
            embedding: None,
            ..Default::default()
        }
        .resolve_preset();

        assert!(resolved.embedding.is_none(), "no-preset path must not touch embedding");
    }
}
|
||||
160
crates/kreuzberg/src/core/config/tree_sitter.rs
Normal file
160
crates/kreuzberg/src/core/config/tree_sitter.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
//! Tree-sitter language pack configuration.
|
||||
//!
|
||||
//! This module contains configuration types for the tree-sitter integration,
|
||||
//! including grammar download settings and code analysis processing options.
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Content rendering mode for code extraction.
///
/// Controls how extracted code content is represented in the `content` field
/// of `ExtractionResult`.
///
/// Variants are (de)serialized in snake_case: `"chunks"`, `"raw"`, `"structure"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CodeContentMode {
    /// Use TSLP semantic chunks as content (default).
    #[default]
    Chunks,
    /// Use raw source code as content.
    Raw,
    /// Emit function/class headings + docstrings (no code bodies).
    Structure,
}
|
||||
|
||||
/// Configuration for tree-sitter language pack integration.
///
/// Controls grammar download behavior and code analysis options.
///
/// Every field carries a serde default, so an empty `[tree_sitter]` section
/// deserializes successfully.
///
/// # Example (TOML)
///
/// ```toml
/// [tree_sitter]
/// languages = ["python", "rust"]
/// groups = ["web"]
///
/// [tree_sitter.process]
/// structure = true
/// comments = true
/// docstrings = true
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeSitterConfig {
    /// Enable code intelligence processing (default: true).
    ///
    /// When `false`, tree-sitter analysis is completely skipped even if
    /// the config section is present.
    #[serde(default = "default_true")]
    pub enabled: bool,

    /// Custom cache directory for downloaded grammars.
    ///
    /// When `None`, uses the default: `~/.cache/tree-sitter-language-pack/v{version}/libs/`.
    #[serde(default)]
    pub cache_dir: Option<PathBuf>,

    /// Languages to pre-download on init (e.g., `["python", "rust"]`).
    ///
    /// `None` when the key is omitted.
    #[serde(default)]
    pub languages: Option<Vec<String>>,

    /// Language groups to pre-download (e.g., `["web", "systems", "scripting"]`).
    ///
    /// `None` when the key is omitted.
    #[serde(default)]
    pub groups: Option<Vec<String>>,

    /// Processing options for code analysis.
    ///
    /// Falls back to `TreeSitterProcessConfig::default()` when the
    /// `process` section is omitted.
    #[serde(default)]
    pub process: TreeSitterProcessConfig,
}
|
||||
|
||||
/// Processing options for tree-sitter code analysis.
///
/// Controls which analysis features are enabled when extracting code files.
///
/// Every field is optional when deserializing: `structure`, `imports` and
/// `exports` default to `true`; the remaining flags default to `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeSitterProcessConfig {
    /// Extract structural items (functions, classes, structs, etc.). Default: true.
    #[serde(default = "default_true")]
    pub structure: bool,

    /// Extract import statements. Default: true.
    #[serde(default = "default_true")]
    pub imports: bool,

    /// Extract export statements. Default: true.
    #[serde(default = "default_true")]
    pub exports: bool,

    /// Extract comments. Default: false.
    #[serde(default)]
    pub comments: bool,

    /// Extract docstrings. Default: false.
    #[serde(default)]
    pub docstrings: bool,

    /// Extract symbol definitions. Default: false.
    #[serde(default)]
    pub symbols: bool,

    /// Include parse diagnostics. Default: false.
    #[serde(default)]
    pub diagnostics: bool,

    /// Maximum chunk size in bytes. `None` disables chunking.
    #[serde(default)]
    pub chunk_max_size: Option<usize>,

    /// Content rendering mode for code extraction.
    ///
    /// Defaults to `CodeContentMode::Chunks`. This field is not forwarded by the
    /// `From<&TreeSitterProcessConfig>` conversion into TSLP's `ProcessConfig`.
    #[serde(default)]
    pub content_mode: CodeContentMode,
}
|
||||
|
||||
impl Default for TreeSitterConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
enabled: true,
|
||||
cache_dir: None,
|
||||
languages: None,
|
||||
groups: None,
|
||||
process: TreeSitterProcessConfig::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for TreeSitterProcessConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
structure: true,
|
||||
imports: true,
|
||||
exports: true,
|
||||
comments: false,
|
||||
docstrings: false,
|
||||
symbols: false,
|
||||
diagnostics: false,
|
||||
chunk_max_size: None,
|
||||
content_mode: CodeContentMode::default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Serde default helper for boolean fields that are on unless configured otherwise.
fn default_true() -> bool {
    true
}
|
||||
|
||||
/// Convert kreuzberg's process config to TSLP's `ProcessConfig`.
|
||||
///
|
||||
/// The language field is left empty — callers must set it before use.
|
||||
impl From<&TreeSitterProcessConfig> for tree_sitter_language_pack::ProcessConfig {
|
||||
fn from(p: &TreeSitterProcessConfig) -> Self {
|
||||
Self {
|
||||
language: std::borrow::Cow::Borrowed(""),
|
||||
structure: p.structure,
|
||||
imports: p.imports,
|
||||
exports: p.exports,
|
||||
comments: p.comments,
|
||||
docstrings: p.docstrings,
|
||||
symbols: p.symbols,
|
||||
diagnostics: p.diagnostics,
|
||||
chunk_max_size: p.chunk_max_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
97
crates/kreuzberg/src/core/config_validation/dependencies.rs
Normal file
97
crates/kreuzberg/src/core/config_validation/dependencies.rs
Normal file
@@ -0,0 +1,97 @@
|
||||
//! Cross-section dependency validation.
|
||||
//!
|
||||
//! This module contains validation functions that check dependencies and relationships
|
||||
//! between different configuration sections. These validators ensure that related
|
||||
//! configuration values are consistent and compatible with each other.
|
||||
|
||||
#[cfg(test)]
|
||||
use crate::{KreuzbergError, Result};
|
||||
|
||||
/// Checks that `port` lies in the range 1-65535 (test-only helper).
///
/// # Errors
/// Returns `KreuzbergError::Validation` when `port` is 0 or greater than 65535.
#[cfg(test)]
pub(crate) fn validate_port(port: u32) -> Result<()> {
    match port {
        1..=65535 => Ok(()),
        _ => Err(KreuzbergError::Validation {
            message: format!("Port must be 1-65535, got {}", port),
            source: None,
        }),
    }
}
|
||||
|
||||
/// Checks that `host` is an IP address or a plausible hostname (test-only helper).
///
/// The input is trimmed first. It is accepted when it parses as an IPv4 or IPv6
/// address, or when it consists solely of alphanumeric characters, dots and
/// hyphens, does not start or end with a hyphen, and is not made up entirely of
/// non-empty numeric dot-separated parts (which would be a malformed IPv4 address).
///
/// # Errors
/// Returns `KreuzbergError::Validation` for an empty host or one that fails
/// every check above.
#[cfg(test)]
pub(crate) fn validate_host(host: &str) -> Result<()> {
    let candidate = host.trim();

    // One error constructor for every rejection path; an empty candidate
    // renders as `Invalid host ''`.
    let rejection = || KreuzbergError::Validation {
        message: format!(
            "Invalid host '{}': must be a valid IP address or hostname",
            candidate
        ),
        source: None,
    };

    if candidate.is_empty() {
        return Err(rejection());
    }

    let is_ip_literal = candidate.parse::<std::net::Ipv4Addr>().is_ok()
        || candidate.parse::<std::net::Ipv6Addr>().is_ok();
    if is_ip_literal {
        return Ok(());
    }

    // Purely numeric dotted strings that failed IPv4 parsing are malformed
    // addresses, not hostnames.
    let numeric_labels_only = candidate
        .split('.')
        .all(|label| !label.is_empty() && label.chars().all(char::is_numeric));
    let charset_ok = candidate
        .chars()
        .all(|ch| ch.is_alphanumeric() || matches!(ch, '.' | '-'));
    let hyphen_at_edge = candidate.starts_with('-') || candidate.ends_with('-');

    if numeric_labels_only || !charset_ok || hyphen_at_edge {
        Err(rejection())
    } else {
        Ok(())
    }
}
|
||||
|
||||
/// Checks that `origin` is an acceptable CORS origin (test-only helper).
///
/// The input is trimmed first. It is accepted when it is the wildcard `"*"`, or
/// when it starts with `http://` or `https://` and has at least one character
/// after the scheme. No further URL parsing is performed.
///
/// # Errors
/// Returns `KreuzbergError::Validation` for anything else, including a bare
/// scheme such as `"http://"`.
#[cfg(test)]
pub(crate) fn validate_cors_origin(origin: &str) -> Result<()> {
    let origin = origin.trim();

    if origin == "*" {
        return Ok(());
    }

    // Strip whichever scheme matches and inspect what follows. The two schemes
    // differ in length, so a single length threshold cannot serve both.
    let after_scheme = origin
        .strip_prefix("http://")
        .or_else(|| origin.strip_prefix("https://"));

    match after_scheme {
        Some(rest) if !rest.is_empty() => Ok(()),
        _ => Err(KreuzbergError::Validation {
            message: format!(
                "Invalid CORS origin '{}': must be a valid HTTP/HTTPS URL or '*'",
                origin
            ),
            source: None,
        }),
    }
}
|
||||
|
||||
/// Checks that an upload size limit is non-zero (test-only helper).
///
/// # Errors
/// Returns `KreuzbergError::Validation` when `size` is 0.
#[cfg(test)]
pub(crate) fn validate_upload_size(size: usize) -> Result<()> {
    if size == 0 {
        return Err(KreuzbergError::Validation {
            message: format!("Upload size must be greater than 0, got {}", size),
            source: None,
        });
    }

    Ok(())
}
|
||||
399
crates/kreuzberg/src/core/config_validation/mod.rs
Normal file
399
crates/kreuzberg/src/core/config_validation/mod.rs
Normal file
@@ -0,0 +1,399 @@
|
||||
//! Configuration validation module.
|
||||
//!
|
||||
//! Provides centralized validation for configuration values across all bindings.
|
||||
//! This eliminates duplication of validation logic in Python, TypeScript, Java, Go, and other language bindings.
|
||||
//!
|
||||
//! All validation functions return `Result<()>` and produce detailed error messages
|
||||
//! suitable for user-facing error handling.
|
||||
//!
|
||||
//! # Examples
|
||||
//!
|
||||
//! ```ignore
|
||||
//! use kreuzberg::core::config_validation::{
|
||||
//! validate_binarization_method,
|
||||
//! validate_token_reduction_level,
|
||||
//! validate_language_code,
|
||||
//! };
|
||||
//!
|
||||
//! // Valid values
|
||||
//! assert!(validate_binarization_method("otsu").is_ok());
|
||||
//! assert!(validate_token_reduction_level("moderate").is_ok());
|
||||
//! assert!(validate_language_code("en").is_ok());
|
||||
//!
|
||||
//! // Invalid values
|
||||
//! assert!(validate_binarization_method("invalid").is_err());
|
||||
//! assert!(validate_token_reduction_level("extreme").is_err());
|
||||
//! ```
|
||||
|
||||
mod dependencies;
|
||||
mod sections;
|
||||
|
||||
// Re-export validation functions used in production code
|
||||
pub(crate) use sections::{
|
||||
validate_chunking_params, validate_language_code, validate_ocr_backend, validate_token_reduction_level,
|
||||
};
|
||||
|
||||
// Re-export validation functions used only in tests
|
||||
#[cfg(test)]
|
||||
pub(crate) use dependencies::{validate_cors_origin, validate_host, validate_port, validate_upload_size};
|
||||
#[cfg(test)]
|
||||
pub(crate) use sections::{
|
||||
validate_binarization_method, validate_confidence, validate_dpi, validate_output_format, validate_tesseract_oem,
|
||||
validate_tesseract_psm, validate_vlm_backend_config,
|
||||
};
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
// Tests for section validation functions
|
||||
#[test]
|
||||
fn test_validate_binarization_method_valid() {
|
||||
assert!(validate_binarization_method("otsu").is_ok());
|
||||
assert!(validate_binarization_method("adaptive").is_ok());
|
||||
assert!(validate_binarization_method("sauvola").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_binarization_method_case_insensitive() {
|
||||
assert!(validate_binarization_method("OTSU").is_ok());
|
||||
assert!(validate_binarization_method("Adaptive").is_ok());
|
||||
assert!(validate_binarization_method("SAUVOLA").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_binarization_method_invalid() {
|
||||
let result = validate_binarization_method("invalid");
|
||||
assert!(result.is_err());
|
||||
let msg = result.unwrap_err().to_string();
|
||||
assert!(msg.contains("Invalid binarization method"));
|
||||
assert!(msg.contains("otsu"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_token_reduction_level_valid() {
|
||||
assert!(validate_token_reduction_level("off").is_ok());
|
||||
assert!(validate_token_reduction_level("light").is_ok());
|
||||
assert!(validate_token_reduction_level("moderate").is_ok());
|
||||
assert!(validate_token_reduction_level("aggressive").is_ok());
|
||||
assert!(validate_token_reduction_level("maximum").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_token_reduction_level_case_insensitive() {
|
||||
assert!(validate_token_reduction_level("OFF").is_ok());
|
||||
assert!(validate_token_reduction_level("Moderate").is_ok());
|
||||
assert!(validate_token_reduction_level("MAXIMUM").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_token_reduction_level_invalid() {
|
||||
let result = validate_token_reduction_level("extreme");
|
||||
assert!(result.is_err());
|
||||
let msg = result.unwrap_err().to_string();
|
||||
assert!(msg.contains("Invalid token reduction level"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_ocr_backend_valid() {
|
||||
assert!(validate_ocr_backend("tesseract").is_ok());
|
||||
assert!(validate_ocr_backend("easyocr").is_ok());
|
||||
assert!(validate_ocr_backend("paddleocr").is_ok());
|
||||
}
|
||||
|
||||
#[test]
fn test_validate_ocr_backend_case_insensitive() {
    // Backend names are matched without regard to case.
    for backend in ["TESSERACT", "EasyOCR", "PADDLEOCR"] {
        assert!(validate_ocr_backend(backend).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_ocr_backend_invalid() {
    // `unwrap_err` panics (failing the test) if the bogus backend is accepted.
    let message = validate_ocr_backend("invalid_backend").unwrap_err().to_string();
    assert!(message.contains("Invalid OCR backend"));
}
|
||||
|
||||
#[test]
fn test_validate_language_code_valid_iso639_1() {
    // Two-letter codes for a sample of major languages.
    for code in ["en", "de", "fr", "es", "zh", "ja", "ko"] {
        assert!(validate_language_code(code).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_language_code_valid_iso639_3() {
    // Three-letter codes for the same sample of languages.
    for code in ["eng", "deu", "fra", "spa", "zho", "jpn", "kor"] {
        assert!(validate_language_code(code).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_language_code_case_insensitive() {
    // Upper- and mixed-case codes must be accepted for both code lengths.
    for code in ["EN", "ENG", "De", "DEU"] {
        assert!(validate_language_code(code).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_language_code_all_keyword() {
    // "all" (in any casing) and "*" are the special catch-all values.
    for keyword in ["all", "ALL", "All", "*"] {
        assert!(validate_language_code(keyword).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_language_code_invalid() {
    // `unwrap_err` panics (failing the test) if the unknown code is accepted.
    let message = validate_language_code("invalid").unwrap_err().to_string();
    for fragment in ["Invalid language code", "ISO 639"] {
        assert!(message.contains(fragment));
    }
}
|
||||
|
||||
#[test]
fn test_validate_tesseract_psm_valid() {
    // Every mode in the inclusive range 0..=13 must be accepted.
    (0..=13).for_each(|psm| {
        assert!(validate_tesseract_psm(psm).is_ok(), "PSM {} should be valid", psm);
    });
}
|
||||
|
||||
#[test]
fn test_validate_tesseract_psm_invalid() {
    // Values on either side of the 0..=13 range must be rejected.
    for psm in [-1, 14, 100] {
        assert!(validate_tesseract_psm(psm).is_err());
    }
}
|
||||
|
||||
#[test]
fn test_validate_tesseract_oem_valid() {
    // Every engine mode in the inclusive range 0..=3 must be accepted.
    (0..=3).for_each(|oem| {
        assert!(validate_tesseract_oem(oem).is_ok(), "OEM {} should be valid", oem);
    });
}
|
||||
|
||||
#[test]
fn test_validate_tesseract_oem_invalid() {
    // Values on either side of the 0..=3 range must be rejected.
    for oem in [-1, 4, 10] {
        assert!(validate_tesseract_oem(oem).is_err());
    }
}
|
||||
|
||||
#[test]
fn test_validate_output_format_valid() {
    // The two most common formats must be accepted.
    for format in ["text", "markdown"] {
        assert!(validate_output_format(format).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_output_format_case_insensitive() {
    // Format names are matched without regard to case.
    for format in ["TEXT", "Markdown"] {
        assert!(validate_output_format(format).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_output_format_invalid() {
    // `unwrap_err` panics (failing the test) if "xml" is accepted.
    let message = validate_output_format("xml").unwrap_err().to_string();
    assert!(message.contains("Invalid output format"));
}
|
||||
|
||||
#[test]
fn test_validate_confidence_valid() {
    // Both endpoints of [0.0, 1.0] and interior values must be accepted.
    for confidence in [0.0, 0.5, 1.0, 0.75] {
        assert!(validate_confidence(confidence).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_confidence_invalid() {
    // Values just below 0.0 and above 1.0 must be rejected.
    for confidence in [-0.1, 1.1, 2.0] {
        assert!(validate_confidence(confidence).is_err());
    }
}
|
||||
|
||||
#[test]
fn test_validate_dpi_valid() {
    // Common resolutions plus the smallest accepted value (1).
    for dpi in [72, 96, 300, 600, 1] {
        assert!(validate_dpi(dpi).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_dpi_invalid() {
    // Zero, negative values and values above the 2400 ceiling must be rejected.
    for dpi in [0, -1, 2401] {
        assert!(validate_dpi(dpi).is_err());
    }
}
|
||||
|
||||
#[test]
fn test_validate_chunking_params_valid() {
    // (max_chars, max_overlap) pairs where the overlap is strictly smaller.
    for (max_chars, max_overlap) in [(1000, 200), (500, 50), (1, 0)] {
        assert!(validate_chunking_params(max_chars, max_overlap).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_chunking_params_zero_chars() {
    // `unwrap_err` panics (failing the test) if a zero chunk size is accepted.
    let message = validate_chunking_params(0, 100).unwrap_err().to_string();
    assert!(message.contains("max_chars"));
}
|
||||
|
||||
#[test]
fn test_validate_chunking_params_overlap_too_large() {
    // Overlap equal to the chunk size is rejected, and the message says why.
    let equal_message = validate_chunking_params(100, 100).unwrap_err().to_string();
    assert!(equal_message.contains("overlap"));

    // Overlap larger than the chunk size is rejected as well.
    assert!(validate_chunking_params(100, 150).is_err());
}
|
||||
|
||||
#[test]
fn test_error_messages_are_helpful() {
    // Each error message should point the user at acceptable values.
    let binarization = validate_binarization_method("bad").unwrap_err().to_string();
    for option in ["otsu", "adaptive", "sauvola"] {
        assert!(binarization.contains(option));
    }

    let reduction = validate_token_reduction_level("bad").unwrap_err().to_string();
    for option in ["off", "moderate"] {
        assert!(reduction.contains(option));
    }

    let language = validate_language_code("bad").unwrap_err().to_string();
    for hint in ["ISO 639", "en"] {
        assert!(language.contains(hint));
    }
}
|
||||
|
||||
// Tests for dependency validation functions
|
||||
#[test]
fn test_validate_port_valid() {
    // Lowest port, well-known ports, a typical dev port, and the highest port.
    for port in [1, 80, 443, 8000, 65535] {
        assert!(validate_port(port).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_port_invalid() {
    // `unwrap_err` panics (failing the test) if port 0 is accepted.
    let message = validate_port(0).unwrap_err().to_string();
    for fragment in ["Port must be 1-65535", "0"] {
        assert!(message.contains(fragment));
    }
}
|
||||
|
||||
#[test]
fn test_validate_host_ipv4() {
    // Loopback, unspecified, private-range and broadcast IPv4 literals.
    for host in ["127.0.0.1", "0.0.0.0", "192.168.1.1", "10.0.0.1", "255.255.255.255"] {
        assert!(validate_host(host).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_host_ipv6() {
    // Compressed IPv6 literals, including loopback and the unspecified address.
    for host in ["::1", "::", "2001:db8::1", "fe80::1"] {
        assert!(validate_host(host).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_host_hostname() {
    // Plain, dotted, hyphenated and alphanumeric hostnames.
    for host in ["localhost", "example.com", "sub.example.com", "api-server", "app123"] {
        assert!(validate_host(host).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_host_invalid() {
    // The empty host is rejected with a recognisable message.
    let empty_message = validate_host("").unwrap_err().to_string();
    assert!(empty_message.contains("Invalid host"));

    // Hosts containing spaces and out-of-range IPv4 octets are rejected too.
    for host in ["not a valid host", "256.256.256.256"] {
        assert!(validate_host(host).is_err());
    }
}
|
||||
|
||||
#[test]
fn test_validate_cors_origin_https() {
    // HTTPS origins with a domain, port, subdomain, IP literal and path.
    let origins = [
        "https://example.com",
        "https://localhost:3000",
        "https://sub.example.com",
        "https://192.168.1.1",
        "https://example.com/path",
    ];
    for origin in origins {
        assert!(validate_cors_origin(origin).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_cors_origin_http() {
    // Plain-HTTP origins are accepted as well.
    for origin in ["http://example.com", "http://localhost:3000", "http://127.0.0.1:8000"] {
        assert!(validate_cors_origin(origin).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_cors_origin_wildcard() {
    // The bare wildcard is a valid origin specification.
    let outcome = validate_cors_origin("*");
    assert!(outcome.is_ok());
}
|
||||
|
||||
#[test]
fn test_validate_cors_origin_invalid() {
    // A scheme-less, non-URL string is rejected with a recognisable message.
    let message = validate_cors_origin("not-a-url").unwrap_err().to_string();
    assert!(message.contains("Invalid CORS origin"));

    // Unsupported scheme, missing scheme, and scheme without a host.
    for origin in ["ftp://example.com", "example.com", "http://"] {
        assert!(validate_cors_origin(origin).is_err());
    }
}
|
||||
|
||||
#[test]
fn test_validate_upload_size_valid() {
    // Any non-zero size is accepted, all the way up to `usize::MAX`.
    for size in [1, 1024, 1_000_000, 1_000_000_000, usize::MAX] {
        assert!(validate_upload_size(size).is_ok());
    }
}
|
||||
|
||||
#[test]
fn test_validate_upload_size_invalid() {
    // `unwrap_err` panics (failing the test) if a zero-byte limit is accepted.
    let message = validate_upload_size(0).unwrap_err().to_string();
    for fragment in ["Upload size must be greater than 0", "0"] {
        assert!(message.contains(fragment));
    }
}
|
||||
}
|
||||
574
crates/kreuzberg/src/core/config_validation/sections.rs
Normal file
574
crates/kreuzberg/src/core/config_validation/sections.rs
Normal file
@@ -0,0 +1,574 @@
|
||||
//! Per-section validation functions.
|
||||
//!
|
||||
//! This module contains validation functions for individual configuration sections
|
||||
//! and their specific parameters. Each function validates a specific aspect of
|
||||
//! the configuration and returns detailed error messages when validation fails.
|
||||
|
||||
use crate::{KreuzbergError, Result};
|
||||
|
||||
/// Lowercase names of the image-preprocessing binarization methods that
/// `validate_binarization_method` accepts.
#[cfg(test)]
const VALID_BINARIZATION_METHODS: &[&str] = &[
    "otsu",
    "adaptive",
    "sauvola",
];
|
||||
|
||||
/// Lowercase names of the token reduction levels that
/// `validate_token_reduction_level` accepts.
const VALID_TOKEN_REDUCTION_LEVELS: &[&str] = &[
    "off",
    "light",
    "moderate",
    "aggressive",
    "maximum",
];
|
||||
|
||||
/// Lowercase names of the OCR backends that `validate_ocr_backend` accepts.
///
/// PaddleOCR is listed under both the `paddleocr` and `paddle-ocr` spellings.
const VALID_OCR_BACKENDS: &[&str] = &[
    "tesseract",
    "easyocr",
    "paddleocr",
    "paddle-ocr",
    "vlm",
];
|
||||
|
||||
/// Language codes accepted by `validate_language_code`.
///
/// The list has three groups, all lowercase:
///
/// 1. two-letter ISO 639-1 codes,
/// 2. the three-letter ISO 639-3 code for every language in group 1, so a
///    language can be requested in either form,
/// 3. PaddleOCR-specific identifiers that are not ISO codes.
const VALID_LANGUAGE_CODES: &[&str] = &[
    // ISO 639-1 (two-letter)
    "en",
    "de",
    "fr",
    "es",
    "it",
    "pt",
    "nl",
    "pl",
    "ru",
    "zh",
    "ja",
    "ko",
    "bg",
    "cs",
    "da",
    "el",
    "et",
    "fi",
    "hu",
    "lt",
    "lv",
    "ro",
    "sk",
    "sl",
    "sv",
    "uk",
    "ar",
    "hi",
    "th",
    "tr",
    "vi",
    // ISO 639-3 (three-letter)
    "eng",
    "deu",
    "fra",
    "spa",
    "ita",
    "por",
    "nld",
    "pol",
    "rus",
    "zho",
    "jpn",
    "kor",
    "ces",
    "dan",
    "ell",
    "est",
    "fin",
    "hun",
    "lit",
    "lav",
    "ron",
    "slk",
    "slv",
    "swe",
    "tur",
    // Three-letter counterparts of bg, uk, ar, hi, th and vi, which were
    // previously accepted only in their two-letter form.
    "bul",
    "ukr",
    "ara",
    "hin",
    "tha",
    "vie",
    // PaddleOCR-specific language codes (non-ISO but widely used)
    "ch",
    "chinese_cht",
    "latin",
    "cyrillic",
    "devanagari",
    "arabic",
];
|
||||
|
||||
/// Tesseract Page Segmentation Mode (PSM) values accepted by
/// `validate_tesseract_psm`: every integer from 0 through 13.
#[cfg(test)]
const VALID_TESSERACT_PSM: &[i32] = &[
    0, 1, 2, 3, 4, 5, 6, //
    7, 8, 9, 10, 11, 12, 13,
];
|
||||
|
||||
/// Tesseract OCR Engine Mode (OEM) values accepted by
/// `validate_tesseract_oem`: every integer from 0 through 3.
#[cfg(test)]
const VALID_TESSERACT_OEM: &[i32] = &[
    0, //
    1, //
    2, //
    3,
];
|
||||
|
||||
/// Lowercase output format names accepted by `validate_output_format`.
///
/// Each canonical name is followed by its alias where one exists:
/// `plain`/`text`, `markdown`/`md`, `djot`, `html`, and `structured`/`json`.
#[cfg(test)]
const VALID_OUTPUT_FORMATS: &[&str] = &[
    "plain",
    "text",
    "markdown",
    "md",
    "djot",
    "html",
    "structured",
    "json",
];
|
||||
|
||||
/// Check that `method` names a supported image binarization method.
///
/// The comparison is case-insensitive: the input is lowercased before it is
/// looked up in `VALID_BINARIZATION_METHODS`.
///
/// # Arguments
///
/// * `method` - Candidate method name (e.g., "otsu", "adaptive", "sauvola")
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` quoting the lowercased input and
/// listing every accepted method when the name is not recognised.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_binarization_method;
///
/// assert!(validate_binarization_method("otsu").is_ok());
/// assert!(validate_binarization_method("Adaptive").is_ok());
/// assert!(validate_binarization_method("invalid").is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_binarization_method(method: &str) -> Result<()> {
    let normalized = method.to_lowercase();

    if VALID_BINARIZATION_METHODS.iter().any(|&known| known == normalized) {
        return Ok(());
    }

    let message = format!(
        "Invalid binarization method '{}'. Valid options are: {}",
        normalized,
        VALID_BINARIZATION_METHODS.join(", ")
    );
    Err(KreuzbergError::Validation { message, source: None })
}
|
||||
|
||||
/// Validate a token reduction level string.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `level` - The token reduction level to validate (e.g., "off", "light", "moderate")
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `Ok(())` if the level is valid, or a `ValidationError` with details about valid options.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use kreuzberg::core::config_validation::validate_token_reduction_level;
|
||||
///
|
||||
/// assert!(validate_token_reduction_level("off").is_ok());
|
||||
/// assert!(validate_token_reduction_level("moderate").is_ok());
|
||||
/// assert!(validate_token_reduction_level("extreme").is_err());
|
||||
/// ```
|
||||
pub(crate) fn validate_token_reduction_level(level: &str) -> Result<()> {
|
||||
let level = level.to_lowercase();
|
||||
if VALID_TOKEN_REDUCTION_LEVELS.contains(&level.as_str()) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(KreuzbergError::Validation {
|
||||
message: format!(
|
||||
"Invalid token reduction level '{}'. Valid options are: {}",
|
||||
level,
|
||||
VALID_TOKEN_REDUCTION_LEVELS.join(", ")
|
||||
),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate an OCR backend string.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `backend` - The OCR backend to validate (e.g., "tesseract", "easyocr", "paddleocr")
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `Ok(())` if the backend is valid, or a `ValidationError` with details about valid options.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use kreuzberg::core::config_validation::validate_ocr_backend;
|
||||
///
|
||||
/// assert!(validate_ocr_backend("tesseract").is_ok());
|
||||
/// assert!(validate_ocr_backend("easyocr").is_ok());
|
||||
/// assert!(validate_ocr_backend("invalid").is_err());
|
||||
/// ```
|
||||
pub(crate) fn validate_ocr_backend(backend: &str) -> Result<()> {
|
||||
let backend = backend.to_lowercase();
|
||||
if VALID_OCR_BACKENDS.contains(&backend.as_str()) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(KreuzbergError::Validation {
|
||||
message: format!(
|
||||
"Invalid OCR backend '{}'. Valid options are: {}",
|
||||
backend,
|
||||
VALID_OCR_BACKENDS.join(", ")
|
||||
),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Validate a language code (ISO 639-1 or 639-3 format).
|
||||
///
|
||||
/// Accepts both 2-letter ISO 639-1 codes (e.g., "en", "de") and
|
||||
/// 3-letter ISO 639-3 codes (e.g., "eng", "deu") for broader compatibility.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `code` - The language code to validate
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `Ok(())` if the code is valid, or a `ValidationError` indicating an invalid language code.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use kreuzberg::core::config_validation::validate_language_code;
|
||||
///
|
||||
/// assert!(validate_language_code("en").is_ok());
|
||||
/// assert!(validate_language_code("eng").is_ok());
|
||||
/// assert!(validate_language_code("de").is_ok());
|
||||
/// assert!(validate_language_code("deu").is_ok());
|
||||
/// assert!(validate_language_code("invalid").is_err());
|
||||
/// ```
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub(crate) fn validate_language_code(code: &str) -> Result<()> {
|
||||
let code_lower = code.to_lowercase();
|
||||
|
||||
// Accept "all" and "*" as special values to auto-detect installed languages
|
||||
if code_lower == "all" || code_lower == "*" {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
if VALID_LANGUAGE_CODES.contains(&code_lower.as_str()) {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(KreuzbergError::Validation {
|
||||
message: format!(
|
||||
"Invalid language code '{}'. Use ISO 639-1 (2-letter, e.g., 'en', 'de') \
|
||||
or ISO 639-3 (3-letter, e.g., 'eng', 'deu') codes. \
|
||||
Common codes: en, de, fr, es, it, pt, nl, pl, ru, zh, ja, ko, ar, hi, th.",
|
||||
code
|
||||
),
|
||||
source: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Check that `psm` is a supported tesseract Page Segmentation Mode.
///
/// A value is accepted when it appears in `VALID_TESSERACT_PSM` (0 through 13).
///
/// # Arguments
///
/// * `psm` - The PSM value to validate
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` quoting the rejected value, the valid
/// range and a few commonly used modes.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_tesseract_psm;
///
/// assert!(validate_tesseract_psm(3).is_ok()); // Fully automatic
/// assert!(validate_tesseract_psm(6).is_ok()); // Single block of text
/// assert!(validate_tesseract_psm(14).is_err()); // Out of range
/// ```
#[cfg(test)]
pub(crate) fn validate_tesseract_psm(psm: i32) -> Result<()> {
    let is_known_mode = VALID_TESSERACT_PSM.iter().any(|&mode| mode == psm);
    if is_known_mode {
        return Ok(());
    }

    let message = format!(
        "Invalid tesseract PSM value '{}'. Valid range is 0-13. \
         Common values: 3 (auto), 6 (single block), 11 (sparse text).",
        psm
    );
    Err(KreuzbergError::Validation { message, source: None })
}
|
||||
|
||||
/// Check that `oem` is a supported tesseract OCR Engine Mode.
///
/// A value is accepted when it appears in `VALID_TESSERACT_OEM` (0 through 3).
///
/// # Arguments
///
/// * `oem` - The OEM value to validate
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` quoting the rejected value and
/// describing what each valid mode means.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_tesseract_oem;
///
/// assert!(validate_tesseract_oem(1).is_ok()); // Neural nets (LSTM)
/// assert!(validate_tesseract_oem(2).is_ok()); // Legacy + LSTM
/// assert!(validate_tesseract_oem(4).is_err()); // Out of range
/// ```
#[cfg(test)]
pub(crate) fn validate_tesseract_oem(oem: i32) -> Result<()> {
    let is_known_mode = VALID_TESSERACT_OEM.iter().any(|&mode| mode == oem);
    if is_known_mode {
        return Ok(());
    }

    let message = format!(
        "Invalid tesseract OEM value '{}'. Valid range is 0-3. \
         0=Legacy, 1=LSTM, 2=Legacy+LSTM, 3=Default",
        oem
    );
    Err(KreuzbergError::Validation { message, source: None })
}
|
||||
|
||||
/// Check that `format` names a supported document extraction output format.
///
/// Matching is case-insensitive against `VALID_OUTPUT_FORMATS`, which holds:
/// - "plain" or "text" for plain text output
/// - "markdown" or "md" for Markdown output
/// - "djot" for Djot markup format
/// - "html" for HTML output
/// - "structured" or "json" for structured (JSON) output
///
/// # Arguments
///
/// * `format` - The output format to validate
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` quoting the lowercased input and
/// listing every accepted format when the name is not recognised.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_output_format;
///
/// assert!(validate_output_format("text").is_ok());
/// assert!(validate_output_format("md").is_ok());
/// assert!(validate_output_format("json").is_ok());
/// assert!(validate_output_format("xml").is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_output_format(format: &str) -> Result<()> {
    let normalized = format.to_lowercase();

    if VALID_OUTPUT_FORMATS.iter().any(|&known| known == normalized) {
        return Ok(());
    }

    let message = format!(
        "Invalid output format '{}'. Valid options are: {}",
        normalized,
        VALID_OUTPUT_FORMATS.join(", ")
    );
    Err(KreuzbergError::Validation { message, source: None })
}
|
||||
|
||||
/// Check that a confidence threshold lies in the closed interval `[0.0, 1.0]`.
///
/// # Arguments
///
/// * `confidence` - The confidence threshold to validate
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` quoting the rejected value when it is
/// outside the interval. `NaN` is outside every range and is therefore
/// rejected as well.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_confidence;
///
/// assert!(validate_confidence(0.0).is_ok());
/// assert!(validate_confidence(1.0).is_ok());
/// assert!(validate_confidence(1.5).is_err());
/// assert!(validate_confidence(-0.1).is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_confidence(confidence: f64) -> Result<()> {
    let accepted = 0.0..=1.0;

    if !accepted.contains(&confidence) {
        let message = format!(
            "Invalid confidence threshold '{}'. Must be between 0.0 and 1.0.",
            confidence
        );
        return Err(KreuzbergError::Validation { message, source: None });
    }

    Ok(())
}
|
||||
|
||||
/// Validate a DPI (dots per inch) value.
///
/// DPI must lie between 1 and 2400 inclusive; typical values are 72-600.
///
/// # Arguments
///
/// * `dpi` - The DPI value to validate
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` quoting the rejected value and the
/// accepted range when `dpi` is zero, negative, or greater than 2400.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_dpi;
///
/// assert!(validate_dpi(96).is_ok());
/// assert!(validate_dpi(300).is_ok());
/// assert!(validate_dpi(0).is_err());
/// assert!(validate_dpi(-1).is_err());
/// assert!(validate_dpi(2401).is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_dpi(dpi: i32) -> Result<()> {
    // Largest DPI value the validator accepts.
    const MAX_DPI: i32 = 2400;

    if (1..=MAX_DPI).contains(&dpi) {
        return Ok(());
    }

    // Report the real bounds: the old text only said "positive integer", which
    // was misleading for values rejected because they exceed the upper limit.
    Err(KreuzbergError::Validation {
        message: format!(
            "Invalid DPI value '{}'. Must be between 1 and {} (typically 72-600).",
            dpi, MAX_DPI
        ),
        source: None,
    })
}
|
||||
|
||||
/// Validate chunk size parameters.
|
||||
///
|
||||
/// Checks that max_chars > 0 and max_overlap < max_chars.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `max_chars` - The maximum characters per chunk
|
||||
/// * `max_overlap` - The maximum overlap between chunks
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// `Ok(())` if the parameters are valid, or a `ValidationError` with details about constraints.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```ignore
|
||||
/// use kreuzberg::core::config_validation::validate_chunking_params;
|
||||
///
|
||||
/// assert!(validate_chunking_params(1000, 200).is_ok());
|
||||
/// assert!(validate_chunking_params(500, 50).is_ok());
|
||||
/// assert!(validate_chunking_params(0, 100).is_err()); // max_chars must be > 0
|
||||
/// assert!(validate_chunking_params(100, 150).is_err()); // overlap >= max_chars
|
||||
/// ```
|
||||
pub(crate) fn validate_chunking_params(max_chars: usize, max_overlap: usize) -> Result<()> {
|
||||
if max_chars == 0 {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "max_chars must be greater than 0".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
|
||||
if max_overlap >= max_chars {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: format!(
|
||||
"max_overlap ({}) must be less than max_chars ({})",
|
||||
max_overlap, max_chars
|
||||
),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check that the model string of an [`LlmConfig`](crate::core::config::LlmConfig)
/// is not blank.
///
/// A string that is empty or consists only of whitespace is rejected.
///
/// # Arguments
///
/// * `model` - The model string to validate
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` when `model` is blank.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_llm_config_model;
///
/// assert!(validate_llm_config_model("openai/gpt-4o").is_ok());
/// assert!(validate_llm_config_model("").is_err());
/// assert!(validate_llm_config_model("   ").is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_llm_config_model(model: &str) -> Result<()> {
    let has_content = !model.trim().is_empty();
    if has_content {
        return Ok(());
    }

    let message =
        "LLM config 'model' must not be empty. Provide a model identifier (e.g., 'openai/gpt-4o').".to_string();
    Err(KreuzbergError::Validation { message, source: None })
}
|
||||
|
||||
/// Check that the `"vlm"` OCR backend comes with a usable `vlm_config`.
///
/// Any backend other than `"vlm"` (compared case-insensitively) passes without
/// further checks. For `"vlm"`, a config must be supplied and its model string
/// must pass `validate_llm_config_model`.
///
/// # Arguments
///
/// * `backend` - The OCR backend name
/// * `vlm_config` - The optional VLM config to validate
///
/// # Errors
///
/// Returns `KreuzbergError::Validation` when the backend is `"vlm"` and either
/// no config is given or the config's model string is blank.
///
/// # Examples
///
/// ```ignore
/// use kreuzberg::core::config_validation::validate_vlm_backend_config;
/// use kreuzberg::core::config::LlmConfig;
///
/// assert!(validate_vlm_backend_config("tesseract", None).is_ok());
/// let config = LlmConfig {
///     model: "openai/gpt-4o".to_string(),
///     api_key: None,
///     base_url: None,
///     timeout_secs: None,
///     max_retries: None,
///     temperature: None,
///     max_tokens: None,
/// };
/// assert!(validate_vlm_backend_config("vlm", Some(&config)).is_ok());
/// assert!(validate_vlm_backend_config("vlm", None).is_err());
/// ```
#[cfg(test)]
pub(crate) fn validate_vlm_backend_config(
    backend: &str,
    vlm_config: Option<&crate::core::config::LlmConfig>,
) -> Result<()> {
    // Only the VLM backend has a config requirement.
    if backend.to_lowercase() != "vlm" {
        return Ok(());
    }

    match vlm_config {
        Some(config) => validate_llm_config_model(&config.model),
        None => Err(KreuzbergError::Validation {
            message: "OCR backend 'vlm' requires 'vlm_config' to be set with model endpoint configuration."
                .to_string(),
            source: None,
        }),
    }
}
|
||||
331
crates/kreuzberg/src/core/extractor/batch.rs
Normal file
331
crates/kreuzberg/src/core/extractor/batch.rs
Normal file
@@ -0,0 +1,331 @@
|
||||
//! Batch extraction operations for concurrent processing.
|
||||
//!
|
||||
//! This module provides parallel extraction capabilities for processing
|
||||
//! multiple files or byte arrays concurrently with automatic resource management.
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
use crate::core::config::BatchBytesItem;
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
use crate::core::config::BatchFileItem;
|
||||
use crate::core::config::ExtractionConfig;
|
||||
use crate::core::config::extraction::FileExtractionConfig;
|
||||
use crate::types::ExtractionResult;
|
||||
use crate::{KreuzbergError, Result};
|
||||
use std::future::Future;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
use super::bytes::extract_bytes;
|
||||
use super::file::extract_file;
|
||||
use super::helpers::error_extraction_result;
|
||||
|
||||
/// Shared batch result collection: spawns tasks via callback, collects ordered results.
///
/// `spawn_task` is invoked once per index in `0..count` with a clone of a shared
/// semaphore, and the future it returns is spawned onto a `JoinSet`. Each future
/// resolves to `(index, result, elapsed_ms)`; the result is stored at `index`, so
/// the returned vector is in input order regardless of completion order.
///
/// The concurrency limit is the first value that is set among
/// `config.max_concurrent_extractions`, `config.concurrency.max_threads`, and the
/// resolved thread budget. It is clamped to at least 1 because a semaphore with
/// zero permits would leave every task waiting forever.
///
/// # Errors
///
/// A failed extraction does not fail the batch: it is converted into an error
/// `ExtractionResult` via `error_extraction_result`. The batch itself returns
/// `KreuzbergError::Other` when a spawned task cannot be joined, or when a task
/// never reported a result for its index.
#[cfg(feature = "tokio-runtime")]
async fn collect_batch<F, Fut>(count: usize, config: &ExtractionConfig, spawn_task: F) -> Result<Vec<ExtractionResult>>
where
    F: Fn(usize, Arc<tokio::sync::Semaphore>) -> Fut,
    Fut: Future<Output = (usize, Result<ExtractionResult>, u64)> + Send + 'static,
{
    use tokio::sync::Semaphore;
    use tokio::task::JoinSet;

    if count == 0 {
        return Ok(vec![]);
    }

    let max_concurrent = config
        .max_concurrent_extractions
        .or_else(|| config.concurrency.as_ref().and_then(|c| c.max_threads))
        .unwrap_or_else(|| crate::core::config::concurrency::resolve_thread_budget(config.concurrency.as_ref()))
        // Guard against a configured limit of 0: no task could ever acquire a
        // permit and the whole batch would hang.
        .max(1);
    let semaphore = Arc::new(Semaphore::new(max_concurrent));

    let mut tasks = JoinSet::new();

    for index in 0..count {
        let sem = Arc::clone(&semaphore);
        tasks.spawn(spawn_task(index, sem));
    }

    let mut results: Vec<Option<ExtractionResult>> = vec![None; count];

    while let Some(task_result) = tasks.join_next().await {
        match task_result {
            Ok((index, outcome, elapsed_ms)) => {
                // Per-item failures become error results instead of aborting the batch.
                results[index] = Some(outcome.unwrap_or_else(|e| error_extraction_result(&e, Some(elapsed_ms))));
            }
            Err(join_err) => {
                return Err(KreuzbergError::Other(format!("Task panicked: {}", join_err)));
            }
        }
    }

    // Every joined task fills its own slot, so a missing slot means a task
    // reported the wrong index; surface that as an error rather than a panic.
    results
        .into_iter()
        .collect::<Option<Vec<_>>>()
        .ok_or_else(|| KreuzbergError::Other("Batch task did not report a result for its index".to_string()))
}
|
||||
|
||||
/// Run a single extraction task with semaphore gating, timing, optional timeout, and batch mode.
///
/// The task first waits for a permit from `semaphore`; the timer starts only after the
/// permit is acquired, so time spent queueing is not counted. The future produced by
/// `extract_fn` is run inside the batch-mode scope.
///
/// When `cancel_token` is provided and the timeout fires, the token is signalled so that
/// any blocking PDF operations in progress can observe the cancellation at the next
/// inter-page checkpoint and stop early.
///
/// # Returns
///
/// `(index, result, elapsed_ms)`. On success the elapsed time is also written to
/// `metadata.extraction_duration_ms`. On timeout the result is
/// `KreuzbergError::Timeout`, with `limit_ms` saturating instead of overflowing for
/// very large `timeout_secs`.
///
/// # Panics
///
/// Panics if the semaphore has been closed.
#[cfg(feature = "tokio-runtime")]
async fn run_timed_extraction<F, Fut>(
    index: usize,
    semaphore: Arc<tokio::sync::Semaphore>,
    timeout_secs: Option<u64>,
    cancel_token: Option<crate::cancellation::CancellationToken>,
    extract_fn: F,
) -> (usize, Result<ExtractionResult>, u64)
where
    F: FnOnce() -> Fut,
    Fut: Future<Output = Result<ExtractionResult>>,
{
    let _permit = semaphore
        .acquire()
        .await
        .expect("batch semaphore must not be closed while tasks are running");
    let start = Instant::now();

    let extraction_future = crate::core::batch_mode::with_batch_mode(extract_fn());

    let mut result = match timeout_secs {
        Some(secs) => match tokio::time::timeout(std::time::Duration::from_secs(secs), extraction_future).await {
            Ok(inner) => inner,
            Err(_elapsed) => {
                // Signal the cancellation token so that any blocking PDF thread can
                // detect it at the next inter-page checkpoint and stop processing.
                if let Some(token) = &cancel_token {
                    token.cancel();
                }
                Err(KreuzbergError::Timeout {
                    elapsed_ms: start.elapsed().as_millis() as u64,
                    limit_ms: secs.saturating_mul(1000),
                })
            }
        },
        None => extraction_future.await,
    };

    let elapsed_ms = start.elapsed().as_millis() as u64;

    if let Ok(extracted) = result.as_mut() {
        extracted.metadata.extraction_duration_ms = Some(elapsed_ms);
    }

    (index, result, elapsed_ms)
}
|
||||
|
||||
/// Resolve a per-file config against a base config. Returns owned config.
|
||||
fn resolve_config(base: &ExtractionConfig, file_config: &Option<FileExtractionConfig>) -> ExtractionConfig {
|
||||
match file_config {
|
||||
Some(fc) => base.with_file_overrides(fc),
|
||||
None => base.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract content from multiple files concurrently.
///
/// This function processes multiple files in parallel, automatically managing
/// concurrency to prevent resource exhaustion. The concurrency limit is taken from
/// `ExtractionConfig::max_concurrent_extractions`; when that is unset it falls back to
/// `concurrency.max_threads` and then to the thread budget resolved from the
/// batch-level concurrency settings.
///
/// Each file can optionally specify a [`FileExtractionConfig`] that overrides specific
/// fields from the batch-level `config`. Pass `None` for a file to use the batch defaults.
/// The concurrency limit is always derived from the batch-level `config`.
///
/// # Arguments
///
/// * `items` - Vector of `BatchFileItem` structs, each containing a path and optional
///   per-file configuration overrides.
/// * `config` - Batch-level extraction configuration (provides defaults and batch settings)
///
/// # Returns
///
/// A vector of `ExtractionResult` in the same order as the input items.
///
/// # Errors
///
/// Failures of individual files (including timeouts) do not fail the batch: each one is
/// returned as an error result with the details recorded in its metadata. The call
/// itself returns `Err` only when an extraction task fails to complete, for example
/// because it panicked.
///
/// # Examples
///
/// Simple usage with no per-file overrides:
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_files;
/// use kreuzberg::core::config::{ExtractionConfig, BatchFileItem};
/// use std::path::PathBuf;
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let items = vec![
///     BatchFileItem { path: "doc1.pdf".into(), config: None },
///     BatchFileItem { path: "doc2.pdf".into(), config: None },
/// ];
/// let results = batch_extract_files(items, &config).await?;
/// println!("Processed {} files", results.len());
/// # Ok(())
/// # }
/// ```
///
/// Per-file configuration overrides:
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_files;
/// use kreuzberg::core::config::{ExtractionConfig, BatchFileItem, FileExtractionConfig};
/// use std::path::PathBuf;
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let items = vec![
///     BatchFileItem {
///         path: "scan.pdf".into(),
///         config: Some(FileExtractionConfig { force_ocr: Some(true), ..Default::default() }),
///     },
///     BatchFileItem { path: "notes.txt".into(), config: None },
/// ];
/// let results = batch_extract_files(items, &config).await?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(feature = "otel", tracing::instrument(
    skip(config, items),
    fields(
        extraction.batch_size = items.len(),
    )
))]
pub async fn batch_extract_files(
    items: Vec<BatchFileItem>,
    config: &ExtractionConfig,
) -> Result<Vec<ExtractionResult>> {
    let config_arc = Arc::new(config.clone());
    // Use Arc<Vec> for file items — paths are small, so keeping them all alive is fine.
    let items_arc = Arc::new(items);
    let count = items_arc.len();

    collect_batch(count, config, |index, sem| {
        let cfg = Arc::clone(&config_arc);
        let items = Arc::clone(&items_arc);
        async move {
            let item = &items[index];
            let resolved = resolve_config(&cfg, &item.config);
            let timeout = resolved.extraction_timeout_secs;
            let cancel_token = resolved.cancel_token.clone();
            // `items` outlives the extraction future, so the path is borrowed
            // directly instead of being cloned for every file.
            run_timed_extraction(index, sem, timeout, cancel_token, || async move {
                extract_file(&item.path, None, &resolved).await
            })
            .await
        }
    })
    .await
}
|
||||
|
||||
/// Extract content from multiple in-memory byte arrays concurrently.
///
/// Items are processed in parallel under a concurrency limit that prevents resource
/// exhaustion. The limit can be set via `ExtractionConfig::max_concurrent_extractions`;
/// when unset it is derived from the batch-level concurrency settings.
///
/// Each item can optionally carry a [`FileExtractionConfig`] that overrides specific
/// fields from the batch-level `config`. Pass `None` as the config to use
/// the batch-level defaults for that item.
///
/// # Arguments
///
/// * `items` - Vector of `BatchBytesItem` structs, each containing content bytes,
///   MIME type, and optional per-item configuration overrides.
/// * `config` - Batch-level extraction configuration
///
/// # Returns
///
/// A vector of `ExtractionResult` in the same order as the input items.
///
/// # Examples
///
/// Simple usage with no per-item overrides:
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_bytes;
/// use kreuzberg::core::config::{ExtractionConfig, BatchBytesItem};
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let items = vec![
///     BatchBytesItem { content: b"content 1".to_vec(), mime_type: "text/plain".to_string(), config: None },
///     BatchBytesItem { content: b"content 2".to_vec(), mime_type: "text/plain".to_string(), config: None },
/// ];
/// let results = batch_extract_bytes(items, &config).await?;
/// println!("Processed {} items", results.len());
/// # Ok(())
/// # }
/// ```
///
/// Per-item configuration overrides:
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_bytes;
/// use kreuzberg::core::config::{ExtractionConfig, BatchBytesItem, FileExtractionConfig};
///
/// # async fn example() -> kreuzberg::Result<()> {
/// let config = ExtractionConfig::default();
/// let items = vec![
///     BatchBytesItem { content: b"content".to_vec(), mime_type: "text/plain".to_string(), config: None },
///     BatchBytesItem {
///         content: b"<html>test</html>".to_vec(),
///         mime_type: "text/html".to_string(),
///         config: Some(FileExtractionConfig { force_ocr: Some(true), ..Default::default() }),
///     },
/// ];
/// let results = batch_extract_bytes(items, &config).await?;
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "tokio-runtime")]
#[cfg_attr(feature = "otel", tracing::instrument(
    skip(config, items),
    fields(
        extraction.batch_size = items.len(),
    )
))]
pub async fn batch_extract_bytes(
    items: Vec<BatchBytesItem>,
    config: &ExtractionConfig,
) -> Result<Vec<ExtractionResult>> {
    type ItemSlot = parking_lot::Mutex<Option<BatchBytesItem>>;

    let shared_config = Arc::new(config.clone());
    let total = items.len();

    // Each item sits in its own lockable slot so that the task handling it can take
    // ownership of the bytes without cloning them. Sharing the items behind a single
    // Arc<Vec<BatchBytesItem>> would instead keep every byte array alive until the
    // whole batch finished.
    let slot_list: Vec<ItemSlot> = items
        .into_iter()
        .map(|entry| parking_lot::Mutex::new(Some(entry)))
        .collect();
    let shared_slots = Arc::new(slot_list);

    collect_batch(total, config, |index, permits| {
        let base = Arc::clone(&shared_config);
        let slot_handle = Arc::clone(&shared_slots);
        async move {
            // The lock guard is released at the end of this statement, before any await.
            let claimed = slot_handle[index].lock().take();
            let entry = claimed.expect("batch item already consumed");
            let effective = resolve_config(&base, &entry.config);
            let cancel_token = effective.cancel_token.clone();
            let timeout = effective.extraction_timeout_secs;
            let job = || async move { extract_bytes(&entry.content, &entry.mime_type, &effective).await };
            run_timed_extraction(index, permits, timeout, cancel_token, job).await
        }
    })
    .await
}
|
||||
170
crates/kreuzberg/src/core/extractor/bytes.rs
Normal file
170
crates/kreuzberg/src/core/extractor/bytes.rs
Normal file
@@ -0,0 +1,170 @@
|
||||
//! Byte array extraction operations.
|
||||
//!
|
||||
//! This module handles extraction from in-memory byte arrays, including:
|
||||
//! - MIME type validation
|
||||
//! - Legacy format conversion (DOC, PPT)
|
||||
//! - Extraction pipeline orchestration
|
||||
|
||||
#[cfg(not(feature = "office"))]
|
||||
use crate::KreuzbergError;
|
||||
use crate::Result;
|
||||
use crate::core::config::ExtractionConfig;
|
||||
use crate::core::mime::{LEGACY_POWERPOINT_MIME_TYPE, LEGACY_WORD_MIME_TYPE};
|
||||
use crate::types::ExtractionResult;
|
||||
|
||||
use super::file::extract_bytes_with_extractor;
|
||||
|
||||
/// Extract content from a byte array.
|
||||
///
|
||||
/// This is the main entry point for in-memory extraction. It performs the following steps:
|
||||
/// 1. Validate MIME type
|
||||
/// 2. Handle legacy format conversion if needed
|
||||
/// 3. Select appropriate extractor from registry
|
||||
/// 4. Extract content
|
||||
/// 5. Run post-processing pipeline
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `content` - The byte array to extract
|
||||
/// * `mime_type` - MIME type of the content
|
||||
/// * `config` - Extraction configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An `ExtractionResult` containing the extracted content and metadata.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if MIME type is invalid.
|
||||
/// Returns `KreuzbergError::UnsupportedFormat` if MIME type is not supported.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use kreuzberg::core::extractor::extract_bytes;
|
||||
/// use kreuzberg::core::config::ExtractionConfig;
|
||||
///
|
||||
/// # async fn example() -> kreuzberg::Result<()> {
|
||||
/// let config = ExtractionConfig::default();
|
||||
/// let bytes = b"Hello, world!";
|
||||
/// let result = extract_bytes(bytes, "text/plain", &config).await?;
|
||||
/// println!("Content: {}", result.content);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[cfg_attr(feature = "otel", tracing::instrument(
|
||||
skip(config, content),
|
||||
fields(
|
||||
{ crate::telemetry::conventions::OPERATION } = crate::telemetry::conventions::operations::EXTRACT_BYTES,
|
||||
{ crate::telemetry::conventions::DOCUMENT_MIME_TYPE } = mime_type,
|
||||
{ crate::telemetry::conventions::DOCUMENT_SIZE_BYTES } = content.len(),
|
||||
{ crate::telemetry::conventions::OTEL_STATUS_CODE } = tracing::field::Empty,
|
||||
{ crate::telemetry::conventions::ERROR_TYPE } = tracing::field::Empty,
|
||||
{ crate::telemetry::conventions::ERROR_MESSAGE } = tracing::field::Empty,
|
||||
)
|
||||
))]
|
||||
pub async fn extract_bytes(content: &[u8], mime_type: &str, config: &ExtractionConfig) -> Result<ExtractionResult> {
|
||||
use crate::core::mime;
|
||||
|
||||
let extraction_future = async {
|
||||
if config.force_ocr && config.effective_disable_ocr() {
|
||||
return Err(crate::KreuzbergError::Validation {
|
||||
message: "force_ocr and disable_ocr cannot both be true".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
|
||||
let validated_mime = if mime_type == "application/octet-stream" {
|
||||
// When tree-sitter is configured, check if content is recognized source code.
|
||||
// This allows octet-stream files with tree-sitter config to be detected as code.
|
||||
#[cfg(feature = "tree-sitter")]
|
||||
{
|
||||
if config.tree_sitter.is_some() {
|
||||
if let Ok(text) = std::str::from_utf8(content) {
|
||||
let trimmed = text.trim_start();
|
||||
if tree_sitter_language_pack::detect_language_from_content(trimmed).is_some() {
|
||||
// Recognize as source code when tree-sitter can detect a language.
|
||||
mime::SOURCE_CODE_MIME_TYPE.to_string()
|
||||
} else {
|
||||
mime::detect_mime_type_from_bytes(content)?
|
||||
}
|
||||
} else {
|
||||
mime::detect_mime_type_from_bytes(content)?
|
||||
}
|
||||
} else {
|
||||
mime::detect_mime_type_from_bytes(content)?
|
||||
}
|
||||
}
|
||||
#[cfg(not(feature = "tree-sitter"))]
|
||||
{
|
||||
let _ = config;
|
||||
mime::detect_mime_type_from_bytes(content)?
|
||||
}
|
||||
} else {
|
||||
mime::validate_mime_type(mime_type)?
|
||||
};
|
||||
|
||||
// Native DOC/PPT extractors are registered in the plugin registry.
|
||||
// When the office feature is disabled, these MIME types are unsupported.
|
||||
#[cfg(not(feature = "office"))]
|
||||
match validated_mime.as_str() {
|
||||
LEGACY_WORD_MIME_TYPE => {
|
||||
return Err(KreuzbergError::UnsupportedFormat(
|
||||
"Legacy Word extraction requires the `office` feature".to_string(),
|
||||
));
|
||||
}
|
||||
LEGACY_POWERPOINT_MIME_TYPE => {
|
||||
return Err(KreuzbergError::UnsupportedFormat(
|
||||
"Legacy PowerPoint extraction requires the `office` feature".to_string(),
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Suppress unused import warnings when office feature is enabled
|
||||
#[cfg(feature = "office")]
|
||||
{
|
||||
let _ = LEGACY_WORD_MIME_TYPE;
|
||||
let _ = LEGACY_POWERPOINT_MIME_TYPE;
|
||||
}
|
||||
|
||||
extract_bytes_with_extractor(content, &validated_mime, config).await
|
||||
};
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
let result = if let Some(secs) = config.extraction_timeout_secs {
|
||||
let start = std::time::Instant::now();
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(secs), extraction_future).await {
|
||||
Ok(inner) => inner,
|
||||
Err(_elapsed) => {
|
||||
if let Some(ref token) = config.cancel_token {
|
||||
token.cancel();
|
||||
}
|
||||
Err(crate::KreuzbergError::Timeout {
|
||||
elapsed_ms: start.elapsed().as_millis() as u64,
|
||||
limit_ms: secs * 1000,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
extraction_future.await
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "tokio-runtime"))]
|
||||
let result = {
|
||||
if config.extraction_timeout_secs.is_some() {
|
||||
return Err(crate::KreuzbergError::Validation {
|
||||
message: "extraction_timeout_secs requires the 'tokio-runtime' feature to be enabled".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
extraction_future.await
|
||||
};
|
||||
|
||||
#[cfg(feature = "otel")]
|
||||
if let Err(ref e) = result {
|
||||
crate::telemetry::spans::record_error_on_current_span(e);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
279
crates/kreuzberg/src/core/extractor/file.rs
Normal file
279
crates/kreuzberg/src/core/extractor/file.rs
Normal file
@@ -0,0 +1,279 @@
|
||||
//! File-based extraction operations.
|
||||
//!
|
||||
//! This module handles extraction from filesystem paths, including:
|
||||
//! - MIME type detection and validation
|
||||
//! - Legacy format conversion (DOC, PPT)
|
||||
//! - File validation and reading
|
||||
//! - Extraction pipeline orchestration
|
||||
|
||||
#[cfg(not(feature = "office"))]
|
||||
use crate::KreuzbergError;
|
||||
use crate::Result;
|
||||
use crate::core::config::ExtractionConfig;
|
||||
use crate::core::mime::{LEGACY_POWERPOINT_MIME_TYPE, LEGACY_WORD_MIME_TYPE};
|
||||
use crate::types::ExtractionResult;
|
||||
use std::path::Path;
|
||||
|
||||
use super::helpers::get_extractor;
|
||||
|
||||
/// Extract content from a file.
|
||||
///
|
||||
/// This is the main entry point for file-based extraction. It performs the following steps:
|
||||
/// 1. Check cache for existing result (if caching enabled)
|
||||
/// 2. Detect or validate MIME type
|
||||
/// 3. Select appropriate extractor from registry
|
||||
/// 4. Extract content
|
||||
/// 5. Run post-processing pipeline
|
||||
/// 6. Store result in cache (if caching enabled)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the file to extract
|
||||
/// * `mime_type` - Optional MIME type override. If None, will be auto-detected
|
||||
/// * `config` - Extraction configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An `ExtractionResult` containing the extracted content and metadata.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Io` if the file doesn't exist (NotFound) or for other file I/O errors.
|
||||
/// Returns `KreuzbergError::UnsupportedFormat` if MIME type is not supported.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use kreuzberg::core::extractor::extract_file;
|
||||
/// use kreuzberg::core::config::ExtractionConfig;
|
||||
///
|
||||
/// # async fn example() -> kreuzberg::Result<()> {
|
||||
/// let config = ExtractionConfig::default();
|
||||
/// let result = extract_file("document.pdf", None, &config).await?;
|
||||
/// println!("Content: {}", result.content);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[cfg_attr(feature = "otel", tracing::instrument(
|
||||
skip(config, path),
|
||||
fields(
|
||||
{ crate::telemetry::conventions::OPERATION } = crate::telemetry::conventions::operations::EXTRACT_FILE,
|
||||
{ crate::telemetry::conventions::DOCUMENT_FILENAME } = tracing::field::Empty,
|
||||
{ crate::telemetry::conventions::OTEL_STATUS_CODE } = tracing::field::Empty,
|
||||
{ crate::telemetry::conventions::ERROR_TYPE } = tracing::field::Empty,
|
||||
{ crate::telemetry::conventions::ERROR_MESSAGE } = tracing::field::Empty,
|
||||
)
|
||||
))]
|
||||
pub async fn extract_file(
|
||||
path: impl AsRef<Path>,
|
||||
mime_type: Option<&str>,
|
||||
config: &ExtractionConfig,
|
||||
) -> Result<ExtractionResult> {
|
||||
use crate::core::{io, mime};
|
||||
|
||||
let path = path.as_ref();
|
||||
|
||||
#[cfg(feature = "otel")]
|
||||
{
|
||||
let span = tracing::Span::current();
|
||||
span.record(
|
||||
crate::telemetry::conventions::DOCUMENT_FILENAME,
|
||||
crate::telemetry::spans::sanitize_path(path),
|
||||
);
|
||||
}
|
||||
|
||||
let extraction_future = async {
|
||||
io::validate_file_exists(path)?;
|
||||
|
||||
if config.force_ocr && config.effective_disable_ocr() {
|
||||
return Err(crate::KreuzbergError::Validation {
|
||||
message: "force_ocr and disable_ocr cannot both be true".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
|
||||
let detected_mime = mime::detect_or_validate(path.to_str(), mime_type)?;
|
||||
|
||||
// Native DOC/PPT extractors are registered in the plugin registry.
|
||||
// When the office feature is disabled, these MIME types are unsupported.
|
||||
#[cfg(not(feature = "office"))]
|
||||
match detected_mime.as_str() {
|
||||
LEGACY_WORD_MIME_TYPE => {
|
||||
return Err(KreuzbergError::UnsupportedFormat(
|
||||
"Legacy Word extraction requires the `office` feature".to_string(),
|
||||
));
|
||||
}
|
||||
LEGACY_POWERPOINT_MIME_TYPE => {
|
||||
return Err(KreuzbergError::UnsupportedFormat(
|
||||
"Legacy PowerPoint extraction requires the `office` feature".to_string(),
|
||||
));
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
// Suppress unused import warnings when office feature is enabled
|
||||
#[cfg(feature = "office")]
|
||||
{
|
||||
let _ = LEGACY_WORD_MIME_TYPE;
|
||||
let _ = LEGACY_POWERPOINT_MIME_TYPE;
|
||||
}
|
||||
|
||||
extract_file_with_extractor(path, &detected_mime, config).await
|
||||
};
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
let result = if let Some(secs) = config.extraction_timeout_secs {
|
||||
let start = std::time::Instant::now();
|
||||
match tokio::time::timeout(std::time::Duration::from_secs(secs), extraction_future).await {
|
||||
Ok(inner) => inner,
|
||||
Err(_elapsed) => {
|
||||
if let Some(ref token) = config.cancel_token {
|
||||
token.cancel();
|
||||
}
|
||||
Err(crate::KreuzbergError::Timeout {
|
||||
elapsed_ms: start.elapsed().as_millis() as u64,
|
||||
limit_ms: secs * 1000,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
extraction_future.await
|
||||
};
|
||||
|
||||
#[cfg(not(feature = "tokio-runtime"))]
|
||||
let result = {
|
||||
if config.extraction_timeout_secs.is_some() {
|
||||
return Err(crate::KreuzbergError::Validation {
|
||||
message: "extraction_timeout_secs requires the 'tokio-runtime' feature to be enabled".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
extraction_future.await
|
||||
};
|
||||
|
||||
#[cfg(feature = "otel")]
|
||||
if let Err(ref e) = result {
|
||||
crate::telemetry::spans::record_error_on_current_span(e);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
pub(in crate::core::extractor) async fn extract_file_with_extractor(
|
||||
path: &Path,
|
||||
mime_type: &str,
|
||||
config: &ExtractionConfig,
|
||||
) -> Result<ExtractionResult> {
|
||||
// Normalize config so cache keys are consistent for ElementBased requests
|
||||
// regardless of whether the caller explicitly set extract_pages.
|
||||
let config = config.normalized();
|
||||
let config = config.as_ref();
|
||||
|
||||
// Skip cache if disabled or TTL=0
|
||||
if !config.use_cache || config.cache_ttl_secs == Some(0) {
|
||||
return extract_file_uncached(path, mime_type, config).await;
|
||||
}
|
||||
|
||||
// Generate cache key from file content hash + config fingerprint
|
||||
let content_hash = crate::cache::blake3_hash_file(path)?;
|
||||
let config_hash = hash_extraction_config(config, mime_type);
|
||||
let cache_key = format!("{content_hash}_{config_hash}");
|
||||
|
||||
let namespace = config.cache_namespace.as_deref();
|
||||
|
||||
// Try cache read
|
||||
if let Some(cache) = get_extraction_cache()
|
||||
&& let Ok(Some(data)) = cache.get(&cache_key, path.to_str(), namespace, config.cache_ttl_secs)
|
||||
&& let Ok(result) = rmp_serde::from_slice::<ExtractionResult>(&data)
|
||||
{
|
||||
tracing::debug!(cache_key = %cache_key, "Extraction cache hit");
|
||||
return Ok(result);
|
||||
}
|
||||
|
||||
// Cache miss — extract
|
||||
let result = extract_file_uncached(path, mime_type, config).await?;
|
||||
|
||||
// Cache write (best-effort)
|
||||
if let Some(cache) = get_extraction_cache()
|
||||
&& let Ok(data) = rmp_serde::to_vec(&result)
|
||||
{
|
||||
let _ = cache.set(&cache_key, data, path.to_str(), namespace, config.cache_ttl_secs);
|
||||
}
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Extract without caching logic.
|
||||
async fn extract_file_uncached(path: &Path, mime_type: &str, config: &ExtractionConfig) -> Result<ExtractionResult> {
|
||||
let budget = crate::core::config::concurrency::resolve_thread_budget(config.concurrency.as_ref());
|
||||
crate::core::config::concurrency::init_thread_pools(budget);
|
||||
|
||||
crate::extractors::ensure_initialized()?;
|
||||
|
||||
let extractor = get_extractor(mime_type)?;
|
||||
let doc = extractor.extract_file(path, mime_type, config).await?;
|
||||
let result = crate::core::pipeline::run_pipeline(doc, config).await?;
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Hash ExtractionConfig fields that affect extraction output.
|
||||
///
|
||||
/// Excludes cache-control fields (use_cache, cache_namespace, cache_ttl_secs)
|
||||
/// since they don't affect the extraction result. Uses a clone-and-normalize
|
||||
/// approach to ensure determinism: cache fields are zeroed, then the struct
|
||||
/// is serialized to canonical JSON via serde_json's sorted-keys representation.
|
||||
fn hash_extraction_config(config: &ExtractionConfig, mime_type: &str) -> String {
|
||||
let mut normalized = config.clone();
|
||||
// Zero out cache-control fields so they don't affect the hash
|
||||
normalized.use_cache = true;
|
||||
normalized.cache_namespace = None;
|
||||
normalized.cache_ttl_secs = None;
|
||||
|
||||
let mut hasher = blake3::Hasher::new();
|
||||
hasher.update(mime_type.as_bytes());
|
||||
// Use MessagePack for deterministic serialization (no float formatting issues,
|
||||
// no HashMap key ordering issues — serde serializes struct fields in declaration order).
|
||||
if let Ok(bytes) = rmp_serde::to_vec(&normalized) {
|
||||
hasher.update(&bytes);
|
||||
}
|
||||
let hash = hasher.finalize();
|
||||
hex::encode(&hash.as_bytes()[..16])
|
||||
}
|
||||
|
||||
/// Get or initialize the global extraction cache.
|
||||
fn get_extraction_cache() -> Option<&'static crate::cache::GenericCache> {
|
||||
use std::sync::OnceLock;
|
||||
static CACHE: OnceLock<Option<crate::cache::GenericCache>> = OnceLock::new();
|
||||
|
||||
CACHE
|
||||
.get_or_init(|| {
|
||||
crate::cache::GenericCache::new(
|
||||
"extraction".to_string(),
|
||||
None,
|
||||
30.0, // 30-day default TTL
|
||||
2000.0, // 2 GB max cache size
|
||||
500.0, // 500 MB min free space
|
||||
)
|
||||
.ok()
|
||||
})
|
||||
.as_ref()
|
||||
}
|
||||
|
||||
pub(in crate::core::extractor) async fn extract_bytes_with_extractor(
|
||||
content: &[u8],
|
||||
mime_type: &str,
|
||||
config: &ExtractionConfig,
|
||||
) -> Result<ExtractionResult> {
|
||||
let config = config.normalized();
|
||||
let config = config.as_ref();
|
||||
|
||||
let budget = crate::core::config::concurrency::resolve_thread_budget(config.concurrency.as_ref());
|
||||
crate::core::config::concurrency::init_thread_pools(budget);
|
||||
|
||||
crate::extractors::ensure_initialized()?;
|
||||
|
||||
let extractor = get_extractor(mime_type)?;
|
||||
let doc = extractor.extract_bytes(content, mime_type, config).await?;
|
||||
let result = crate::core::pipeline::run_pipeline(doc, config).await?;
|
||||
Ok(result)
|
||||
}
|
||||
85
crates/kreuzberg/src/core/extractor/helpers.rs
Normal file
85
crates/kreuzberg/src/core/extractor/helpers.rs
Normal file
@@ -0,0 +1,85 @@
|
||||
//! Helper functions and utilities for extraction operations.
|
||||
//!
|
||||
//! This module provides shared utilities used across extraction modules.
|
||||
|
||||
use crate::plugins::DocumentExtractor;
|
||||
use crate::types::{ErrorMetadata, ExtractionResult, Metadata};
|
||||
use crate::{KreuzbergError, Result};
|
||||
use std::borrow::Cow;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Get an extractor from the registry.
|
||||
///
|
||||
/// This function acquires the registry read lock and retrieves the appropriate
|
||||
/// extractor for the given MIME type.
|
||||
///
|
||||
/// When the `otel` feature is enabled, the returned extractor is wrapped in an
|
||||
/// [`InstrumentedExtractor`](crate::plugins::extractor::instrumented::InstrumentedExtractor)
|
||||
/// that adds tracing spans and metrics automatically.
|
||||
///
|
||||
/// # Performance
|
||||
///
|
||||
/// RwLock read + HashMap lookup is ~100ns, fast enough without caching.
|
||||
/// Removed thread-local cache to avoid Tokio work-stealing scheduler issues.
|
||||
pub(in crate::core::extractor) fn get_extractor(mime_type: &str) -> Result<Arc<dyn DocumentExtractor>> {
|
||||
let registry = crate::plugins::registry::get_document_extractor_registry();
|
||||
let registry_read = registry.read();
|
||||
let extractor = registry_read.get(mime_type)?;
|
||||
|
||||
#[cfg(feature = "otel")]
|
||||
{
|
||||
Ok(Arc::new(
|
||||
crate::plugins::extractor::instrumented::InstrumentedExtractor::new(extractor),
|
||||
))
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "otel"))]
|
||||
{
|
||||
Ok(extractor)
|
||||
}
|
||||
}
|
||||
|
||||
/// Get optimal pool sizing hint for a document.
|
||||
///
|
||||
/// This function calculates recommended pool sizes based on the document's
|
||||
/// file size and MIME type. The hint can be used to create appropriately
|
||||
/// sized thread pools for extraction, reducing memory waste from over-allocation.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `file_size` - The size of the file in bytes
|
||||
/// * `mime_type` - The MIME type of the document
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A `PoolSizeHint` with recommended pool configurations
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,ignore
|
||||
/// use kreuzberg::core::extractor::get_pool_sizing_hint;
|
||||
///
|
||||
/// let hint = get_pool_sizing_hint(5_000_000, "application/pdf");
|
||||
/// println!("Recommended string buffers: {}", hint.string_buffer_count);
|
||||
/// ```
|
||||
/// Build an error `ExtractionResult` for failed batch items.
|
||||
///
|
||||
/// Used by both tokio-based batch functions and WASM synchronous fallbacks
|
||||
/// to construct a uniform error result.
|
||||
pub(crate) fn error_extraction_result(e: &KreuzbergError, elapsed_ms: Option<u64>) -> ExtractionResult {
|
||||
let metadata = Metadata {
|
||||
error: Some(ErrorMetadata {
|
||||
error_type: format!("{:?}", e),
|
||||
message: e.to_string(),
|
||||
}),
|
||||
extraction_duration_ms: elapsed_ms,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
ExtractionResult {
|
||||
content: format!("Error: {}", e),
|
||||
mime_type: Cow::Borrowed("text/plain"),
|
||||
metadata,
|
||||
..Default::default()
|
||||
}
|
||||
}
|
||||
67
crates/kreuzberg/src/core/extractor/legacy.rs
Normal file
67
crates/kreuzberg/src/core/extractor/legacy.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
//! Legacy synchronous extraction for WASM compatibility.
|
||||
//!
|
||||
//! This module provides truly synchronous extraction implementations
|
||||
//! for environments where Tokio runtime is not available (e.g., WASM).
|
||||
|
||||
/// Synchronous extraction implementation for WASM compatibility.
|
||||
///
|
||||
/// This function performs extraction without requiring a tokio runtime.
|
||||
/// It calls the sync extractor methods directly.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `content` - The byte content to extract
|
||||
/// * `mime_type` - Optional MIME type to validate/use
|
||||
/// * `config` - Optional extraction configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// An `ExtractionResult` or a `KreuzbergError`
|
||||
///
|
||||
/// # Implementation Notes
|
||||
///
|
||||
/// This is called when the `tokio-runtime` feature is disabled.
|
||||
/// It replicates the logic of `extract_bytes` but uses synchronous extractor methods.
|
||||
#[cfg(not(feature = "tokio-runtime"))]
|
||||
pub(super) fn extract_bytes_sync_impl(
|
||||
content: &[u8],
|
||||
mime_type: Option<&str>,
|
||||
config: Option<&crate::core::config::ExtractionConfig>,
|
||||
) -> crate::Result<crate::types::ExtractionResult> {
|
||||
use crate::KreuzbergError;
|
||||
use crate::core::extractor::helpers::get_extractor;
|
||||
use crate::core::mime;
|
||||
|
||||
let cfg = config.cloned().unwrap_or_default();
|
||||
let cfg = cfg.normalized().into_owned();
|
||||
|
||||
let validated_mime = if let Some(mime) = mime_type {
|
||||
if mime == "application/octet-stream" {
|
||||
mime::detect_mime_type_from_bytes(content)?
|
||||
} else {
|
||||
mime::validate_mime_type(mime)?
|
||||
}
|
||||
} else {
|
||||
return Err(KreuzbergError::Validation {
|
||||
message: "MIME type is required for synchronous extraction".to_string(),
|
||||
source: None,
|
||||
});
|
||||
};
|
||||
|
||||
crate::extractors::ensure_initialized()?;
|
||||
|
||||
let extractor = get_extractor(&validated_mime)?;
|
||||
|
||||
let sync_extractor = extractor.as_sync_extractor().ok_or_else(|| {
|
||||
KreuzbergError::UnsupportedFormat(format!(
|
||||
"Extractor for '{}' does not support synchronous extraction",
|
||||
validated_mime
|
||||
))
|
||||
})?;
|
||||
|
||||
let doc = sync_extractor.extract_sync(content, &validated_mime, &cfg)?;
|
||||
|
||||
let result = crate::core::pipeline::run_pipeline_sync(doc, &cfg)?;
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
681
crates/kreuzberg/src/core/extractor/mod.rs
Normal file
681
crates/kreuzberg/src/core/extractor/mod.rs
Normal file
@@ -0,0 +1,681 @@
|
||||
//! Main extraction entry points.
|
||||
//!
|
||||
//! This module provides the primary API for extracting content from files and byte arrays.
|
||||
//! It orchestrates the entire extraction pipeline: cache checking, MIME detection,
|
||||
//! extractor selection, extraction, post-processing, and cache storage.
|
||||
//!
|
||||
//! # Functions
|
||||
//!
|
||||
//! - [`extract_file`] - Extract content from a file path
|
||||
//! - [`extract_bytes`] - Extract content from a byte array
|
||||
//! - [`batch_extract_files`] - Extract content from multiple files concurrently
|
||||
//! - [`batch_extract_bytes`] - Extract content from multiple byte arrays concurrently
|
||||
|
||||
mod bytes;
|
||||
mod file;
|
||||
mod helpers;
|
||||
mod legacy;
|
||||
mod sync;
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
mod batch;
|
||||
|
||||
// Re-export public API
|
||||
pub use bytes::extract_bytes;
|
||||
pub use file::extract_file;
|
||||
pub use sync::{batch_extract_bytes_sync, extract_bytes_sync};
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
pub use sync::extract_file_sync;
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
pub use batch::{batch_extract_bytes, batch_extract_files};
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
pub use sync::batch_extract_files_sync;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::config::{BatchBytesItem, BatchFileItem, ExtractionConfig};
|
||||
use serial_test::serial;
|
||||
use std::fs::File;
|
||||
use std::io::Write;
|
||||
use std::sync::Arc;
|
||||
use tempfile::tempdir;
|
||||
|
||||
/// Assert that `actual` equals `expected` once any trailing `'\n'`
/// characters have been stripped from `actual`.
///
/// Only trailing newlines are ignored; other whitespace still counts.
fn assert_text_content(actual: &str, expected: &str) {
    let without_trailing_newlines = actual.trim_end_matches('\n');
    assert_eq!(without_trailing_newlines, expected);
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_basic() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("test.txt");
|
||||
let mut file = File::create(&file_path).unwrap();
|
||||
file.write_all(b"Hello, world!").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file(&file_path, None, &config).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
assert_text_content(&result.content, "Hello, world!");
|
||||
assert_eq!(result.mime_type, "text/plain");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_with_mime_override() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("test.dat");
|
||||
let mut file = File::create(&file_path).unwrap();
|
||||
file.write_all(b"test content").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file(&file_path, Some("text/plain"), &config).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
assert_eq!(result.mime_type, "text/plain");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_nonexistent() {
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file("/nonexistent/file.txt", None, &config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_bytes_basic() {
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_bytes(b"test content", "text/plain", &config).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
assert_text_content(&result.content, "test content");
|
||||
assert_eq!(result.mime_type, "text/plain");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_bytes_invalid_mime() {
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_bytes(b"test", "invalid/mime", &config).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_file() {
|
||||
let dir = tempdir().unwrap();
|
||||
|
||||
let file1 = dir.path().join("test1.txt");
|
||||
let file2 = dir.path().join("test2.txt");
|
||||
|
||||
File::create(&file1).unwrap().write_all(b"content 1").unwrap();
|
||||
File::create(&file2).unwrap().write_all(b"content 2").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let items = vec![
|
||||
BatchFileItem {
|
||||
path: file1,
|
||||
config: None,
|
||||
},
|
||||
BatchFileItem {
|
||||
path: file2,
|
||||
config: None,
|
||||
},
|
||||
];
|
||||
let results = batch_extract_files(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_text_content(&results[0].content, "content 1");
|
||||
assert_text_content(&results[1].content, "content 2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_file_empty() {
|
||||
let config = ExtractionConfig::default();
|
||||
let items: Vec<BatchFileItem> = vec![];
|
||||
let results = batch_extract_files(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
assert_eq!(results.unwrap().len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_bytes() {
|
||||
let config = ExtractionConfig::default();
|
||||
let items = vec![
|
||||
BatchBytesItem {
|
||||
content: b"content 1".to_vec(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
config: None,
|
||||
},
|
||||
BatchBytesItem {
|
||||
content: b"content 2".to_vec(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
config: None,
|
||||
},
|
||||
];
|
||||
let results = batch_extract_bytes(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_text_content(&results[0].content, "content 1");
|
||||
assert_text_content(&results[1].content, "content 2");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync_wrappers() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("test.txt");
|
||||
File::create(&file_path).unwrap().write_all(b"sync test").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
|
||||
let result = extract_file_sync(&file_path, None, &config);
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
assert_text_content(&result.content, "sync test");
|
||||
|
||||
let result = extract_bytes_sync(b"test", "text/plain", &config);
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extractor_cache() {
|
||||
let config = ExtractionConfig::default();
|
||||
|
||||
let result1 = extract_bytes(b"test 1", "text/plain", &config).await;
|
||||
assert!(result1.is_ok());
|
||||
let result1 = result1.unwrap();
|
||||
|
||||
let result2 = extract_bytes(b"test 2", "text/plain", &config).await;
|
||||
assert!(result2.is_ok());
|
||||
let result2 = result2.unwrap();
|
||||
|
||||
assert_text_content(&result1.content, "test 1");
|
||||
assert_text_content(&result2.content, "test 2");
|
||||
|
||||
let result3 = extract_bytes(b"# test 3", "text/markdown", &config).await;
|
||||
assert!(result3.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_empty() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("empty.txt");
|
||||
File::create(&file_path).unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file(&file_path, None, &config).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
assert_eq!(result.content, "");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_bytes_empty() {
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_bytes(b"", "text/plain", &config).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
assert_eq!(result.content, "");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_whitespace_only() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("whitespace.txt");
|
||||
File::create(&file_path).unwrap().write_all(b" \n\t \n ").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file(&file_path, None, &config).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_very_long_path() {
|
||||
let dir = tempdir().unwrap();
|
||||
let long_name = "a".repeat(200);
|
||||
let file_path = dir.path().join(format!("{}.txt", long_name));
|
||||
|
||||
if let Ok(mut f) = File::create(&file_path) {
|
||||
f.write_all(b"content").unwrap();
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file(&file_path, None, &config).await;
|
||||
assert!(result.is_ok() || result.is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_special_characters_in_path() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("test with spaces & symbols!.txt");
|
||||
File::create(&file_path).unwrap().write_all(b"content").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file(&file_path, None, &config).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
assert_text_content(&result.content, "content");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_unicode_filename() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("测试文件名.txt");
|
||||
File::create(&file_path).unwrap().write_all(b"content").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file(&file_path, None, &config).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_bytes_unsupported_mime() {
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_bytes(b"test", "application/x-unknown-format", &config).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
use crate::KreuzbergError;
|
||||
assert!(matches!(result.unwrap_err(), KreuzbergError::UnsupportedFormat(_)));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_file_with_errors() {
|
||||
let dir = tempdir().unwrap();
|
||||
|
||||
let valid_file = dir.path().join("valid.txt");
|
||||
File::create(&valid_file).unwrap().write_all(b"valid content").unwrap();
|
||||
|
||||
let invalid_file = dir.path().join("nonexistent.txt");
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let items = vec![
|
||||
BatchFileItem {
|
||||
path: valid_file,
|
||||
config: None,
|
||||
},
|
||||
BatchFileItem {
|
||||
path: invalid_file,
|
||||
config: None,
|
||||
},
|
||||
];
|
||||
let results = batch_extract_files(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_text_content(&results[0].content, "valid content");
|
||||
assert!(results[1].metadata.error.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_bytes_mixed_valid_invalid() {
|
||||
let config = ExtractionConfig::default();
|
||||
let items = vec![
|
||||
BatchBytesItem {
|
||||
content: b"valid 1".to_vec(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
config: None,
|
||||
},
|
||||
BatchBytesItem {
|
||||
content: b"invalid".to_vec(),
|
||||
mime_type: "invalid/mime".to_string(),
|
||||
config: None,
|
||||
},
|
||||
BatchBytesItem {
|
||||
content: b"valid 2".to_vec(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
config: None,
|
||||
},
|
||||
];
|
||||
let results = batch_extract_bytes(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 3);
|
||||
assert_text_content(&results[0].content, "valid 1");
|
||||
assert!(results[1].metadata.error.is_some());
|
||||
assert_text_content(&results[2].content, "valid 2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_bytes_all_invalid() {
|
||||
let config = ExtractionConfig::default();
|
||||
let items = vec![
|
||||
BatchBytesItem {
|
||||
content: b"test 1".to_vec(),
|
||||
mime_type: "invalid/mime1".to_string(),
|
||||
config: None,
|
||||
},
|
||||
BatchBytesItem {
|
||||
content: b"test 2".to_vec(),
|
||||
mime_type: "invalid/mime2".to_string(),
|
||||
config: None,
|
||||
},
|
||||
];
|
||||
let results = batch_extract_bytes(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results[0].metadata.error.is_some());
|
||||
assert!(results[1].metadata.error.is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_bytes_very_large() {
|
||||
let large_content = vec![b'a'; 10_000_000];
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_bytes(&large_content, "text/plain", &config).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let result = result.unwrap();
|
||||
let trimmed_len = result.content.trim_end_matches('\n').len();
|
||||
assert_eq!(trimmed_len, 10_000_000);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_large_count() {
|
||||
let dir = tempdir().unwrap();
|
||||
let mut items = Vec::new();
|
||||
|
||||
for i in 0..100 {
|
||||
let file_path = dir.path().join(format!("file{}.txt", i));
|
||||
File::create(&file_path)
|
||||
.unwrap()
|
||||
.write_all(format!("content {}", i).as_bytes())
|
||||
.unwrap();
|
||||
items.push(BatchFileItem {
|
||||
path: file_path,
|
||||
config: None,
|
||||
});
|
||||
}
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let results = batch_extract_files(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 100);
|
||||
|
||||
for (i, result) in results.iter().enumerate() {
|
||||
assert_text_content(&result.content, &format!("content {}", i));
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_mime_detection_fallback() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("testfile");
|
||||
File::create(&file_path)
|
||||
.unwrap()
|
||||
.write_all(b"plain text content")
|
||||
.unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file(&file_path, None, &config).await;
|
||||
|
||||
assert!(result.is_ok() || result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_file_wrong_mime_override() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("test.txt");
|
||||
File::create(&file_path).unwrap().write_all(b"plain text").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file(&file_path, Some("application/pdf"), &config).await;
|
||||
|
||||
assert!(result.is_err() || result.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync_wrapper_nonexistent_file() {
|
||||
let config = ExtractionConfig::default();
|
||||
let result = extract_file_sync("/nonexistent/path.txt", None, &config);
|
||||
|
||||
assert!(result.is_err());
|
||||
use crate::KreuzbergError;
|
||||
// File validation returns Io error, not Validation error
|
||||
assert!(matches!(result.unwrap_err(), KreuzbergError::Io { .. }));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync_wrapper_batch_empty() {
|
||||
let config = ExtractionConfig::default();
|
||||
let items: Vec<BatchFileItem> = vec![];
|
||||
let results = batch_extract_files_sync(items, &config);
|
||||
|
||||
assert!(results.is_ok());
|
||||
assert_eq!(results.unwrap().len(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sync_wrapper_batch_bytes_empty() {
|
||||
let config = ExtractionConfig::default();
|
||||
let items: Vec<BatchBytesItem> = vec![];
|
||||
let results = batch_extract_bytes_sync(items, &config);
|
||||
|
||||
assert!(results.is_ok());
|
||||
assert_eq!(results.unwrap().len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_extractions_same_mime() {
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
let config = Arc::new(ExtractionConfig::default());
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
for i in 0..50 {
|
||||
let config_clone = Arc::clone(&config);
|
||||
tasks.spawn(async move {
|
||||
let content = format!("test content {}", i);
|
||||
extract_bytes(content.as_bytes(), "text/plain", &config_clone).await
|
||||
});
|
||||
}
|
||||
|
||||
let mut success_count = 0;
|
||||
while let Some(task_result) = tasks.join_next().await {
|
||||
if let Ok(Ok(_)) = task_result {
|
||||
success_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(success_count, 50);
|
||||
}
|
||||
|
||||
#[serial]
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_extractions_different_mimes() {
|
||||
use tokio::task::JoinSet;
|
||||
|
||||
let config = Arc::new(ExtractionConfig::default());
|
||||
let mut tasks = JoinSet::new();
|
||||
|
||||
let mime_types = ["text/plain", "text/markdown"];
|
||||
|
||||
for i in 0..30 {
|
||||
let config_clone = Arc::clone(&config);
|
||||
let mime = mime_types[i % mime_types.len()];
|
||||
tasks.spawn(async move {
|
||||
let content = format!("test {}", i);
|
||||
extract_bytes(content.as_bytes(), mime, &config_clone).await
|
||||
});
|
||||
}
|
||||
|
||||
let mut success_count = 0;
|
||||
while let Some(task_result) = tasks.join_next().await {
|
||||
if let Ok(Ok(_)) = task_result {
|
||||
success_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
assert_eq!(success_count, 30);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_file_with_per_file_configs() {
|
||||
let dir = tempdir().unwrap();
|
||||
|
||||
let file1 = dir.path().join("test1.txt");
|
||||
let file2 = dir.path().join("test2.txt");
|
||||
File::create(&file1).unwrap().write_all(b"content 1").unwrap();
|
||||
File::create(&file2).unwrap().write_all(b"content 2").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let items = vec![
|
||||
BatchFileItem {
|
||||
path: file1,
|
||||
config: Some(crate::FileExtractionConfig::default()),
|
||||
},
|
||||
BatchFileItem {
|
||||
path: file2,
|
||||
config: None,
|
||||
},
|
||||
];
|
||||
let results = batch_extract_files(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_text_content(&results[0].content, "content 1");
|
||||
assert_text_content(&results[1].content, "content 2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_file_with_configs_empty() {
|
||||
let config = ExtractionConfig::default();
|
||||
let items: Vec<BatchFileItem> = vec![];
|
||||
let results = batch_extract_files(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
assert_eq!(results.unwrap().len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_bytes_with_per_item_configs() {
|
||||
let config = ExtractionConfig::default();
|
||||
let items = vec![
|
||||
BatchBytesItem {
|
||||
content: b"hello".to_vec(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
config: None,
|
||||
},
|
||||
BatchBytesItem {
|
||||
content: b"world".to_vec(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
config: Some(crate::FileExtractionConfig::default()),
|
||||
},
|
||||
];
|
||||
let results = batch_extract_bytes(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_text_content(&results[0].content, "hello");
|
||||
assert_text_content(&results[1].content, "world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_extract_bytes_with_configs_error_handling() {
|
||||
let config = ExtractionConfig::default();
|
||||
let items = vec![
|
||||
BatchBytesItem {
|
||||
content: b"valid".to_vec(),
|
||||
mime_type: "text/plain".to_string(),
|
||||
config: None,
|
||||
},
|
||||
BatchBytesItem {
|
||||
content: b"invalid".to_vec(),
|
||||
mime_type: "invalid/mime".to_string(),
|
||||
config: Some(crate::FileExtractionConfig::default()),
|
||||
},
|
||||
];
|
||||
let results = batch_extract_bytes(items, &config).await;
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert_text_content(&results[0].content, "valid");
|
||||
assert!(results[1].metadata.error.is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_batch_extract_file_sync_with_configs() {
|
||||
let dir = tempdir().unwrap();
|
||||
let file_path = dir.path().join("test.txt");
|
||||
File::create(&file_path).unwrap().write_all(b"sync test").unwrap();
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
let items = vec![BatchFileItem {
|
||||
path: file_path,
|
||||
config: None,
|
||||
}];
|
||||
let results = batch_extract_files_sync(items, &config);
|
||||
|
||||
assert!(results.is_ok());
|
||||
let results = results.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_text_content(&results[0].content, "sync test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_file_overrides_single_field() {
|
||||
let base = ExtractionConfig::default();
|
||||
assert!(!base.force_ocr);
|
||||
|
||||
let overrides = crate::FileExtractionConfig {
|
||||
force_ocr: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
let resolved = base.with_file_overrides(&overrides);
|
||||
assert!(resolved.force_ocr);
|
||||
// Other fields unchanged
|
||||
assert_eq!(resolved.use_cache, base.use_cache);
|
||||
assert_eq!(resolved.enable_quality_processing, base.enable_quality_processing);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_file_overrides_none_keeps_default() {
|
||||
let base = ExtractionConfig::default();
|
||||
let overrides = crate::FileExtractionConfig::default(); // all None
|
||||
let resolved = base.with_file_overrides(&overrides);
|
||||
// All fields should match base
|
||||
assert_eq!(resolved.use_cache, base.use_cache);
|
||||
assert_eq!(resolved.force_ocr, base.force_ocr);
|
||||
assert_eq!(resolved.enable_quality_processing, base.enable_quality_processing);
|
||||
assert_eq!(resolved.include_document_structure, base.include_document_structure);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_with_file_overrides_batch_fields_unaffected() {
|
||||
let base = ExtractionConfig {
|
||||
max_concurrent_extractions: Some(42),
|
||||
use_cache: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let overrides = crate::FileExtractionConfig {
|
||||
force_ocr: Some(true),
|
||||
..Default::default()
|
||||
};
|
||||
let resolved = base.with_file_overrides(&overrides);
|
||||
// Batch-level fields must be preserved from base
|
||||
assert_eq!(resolved.max_concurrent_extractions, Some(42));
|
||||
assert!(!resolved.use_cache);
|
||||
// Override applied
|
||||
assert!(resolved.force_ocr);
|
||||
}
|
||||
}
|
||||
200
crates/kreuzberg/src/core/extractor/sync.rs
Normal file
200
crates/kreuzberg/src/core/extractor/sync.rs
Normal file
@@ -0,0 +1,200 @@
|
||||
//! Synchronous wrappers for extraction operations.
|
||||
//!
|
||||
//! This module provides blocking synchronous wrappers around async extraction functions
|
||||
//! for use in non-async contexts. Uses a global Tokio runtime for optimal performance.
|
||||
|
||||
use crate::Result;
|
||||
use crate::core::config::BatchBytesItem;
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
use crate::core::config::BatchFileItem;
|
||||
use crate::core::config::ExtractionConfig;
|
||||
use crate::types::ExtractionResult;
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
use std::path::Path;
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
use once_cell::sync::OnceCell;
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
use super::batch::{batch_extract_bytes, batch_extract_files};
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
use super::bytes::extract_bytes;
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
use super::file::extract_file;
|
||||
|
||||
#[cfg(not(feature = "tokio-runtime"))]
|
||||
use super::helpers::error_extraction_result;
|
||||
|
||||
/// Global Tokio runtime cell for synchronous operations.
|
||||
///
|
||||
/// Lazily initialized on first use and shared across all sync wrappers.
|
||||
/// Using a global runtime instead of creating one per call provides 100x+ performance improvement.
|
||||
///
|
||||
/// # Availability
|
||||
///
|
||||
/// This static is only available when the `tokio-runtime` feature is enabled.
|
||||
/// For WASM targets, use the truly synchronous extraction functions instead.
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
static GLOBAL_RUNTIME: OnceCell<tokio::runtime::Runtime> = OnceCell::new();
|
||||
|
||||
/// Borrow the process-wide Tokio runtime, creating it on the first call.
///
/// The runtime is a multi-threaded runtime built with `enable_all()` and is
/// stored in `GLOBAL_RUNTIME`, so later calls return the same instance.
///
/// # Errors
///
/// Returns `KreuzbergError::Plugin` (with `plugin_name` set to `"runtime"`)
/// if the runtime cannot be built, e.g. due to system resource exhaustion.
#[cfg(feature = "tokio-runtime")]
fn global_runtime() -> crate::Result<&'static tokio::runtime::Runtime> {
    GLOBAL_RUNTIME.get_or_try_init(|| {
        let mut builder = tokio::runtime::Builder::new_multi_thread();
        builder.enable_all();

        match builder.build() {
            Ok(runtime) => Ok(runtime),
            Err(e) => Err(crate::KreuzbergError::Plugin {
                message: format!("Failed to create global Tokio runtime: {e}"),
                plugin_name: "runtime".to_string(),
            }),
        }
    })
}
|
||||
|
||||
/// Blocking counterpart of `extract_file`.
///
/// Drives `extract_file` to completion on the shared global Tokio runtime and
/// blocks the calling thread until the result is available. Async callers
/// should use `extract_file` directly.
///
/// The global runtime is created once and reused, which avoids paying the
/// runtime start-up cost on every call.
///
/// This function is only available with the `tokio-runtime` feature. For WASM
/// targets, use a truly synchronous extraction approach instead.
///
/// # Errors
///
/// Returns an error if the global runtime cannot be created, or whatever
/// error `extract_file` reports.
///
/// # Panics
///
/// NOTE(review): Tokio's `Runtime::block_on` is documented to panic when
/// called from within an asynchronous execution context — confirm callers
/// never invoke this from async code.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::extract_file_sync;
/// use kreuzberg::core::config::ExtractionConfig;
///
/// let config = ExtractionConfig::default();
/// let result = extract_file_sync("document.pdf", None, &config)?;
/// println!("Content: {}", result.content);
/// # Ok::<(), kreuzberg::KreuzbergError>(())
/// ```
#[cfg(feature = "tokio-runtime")]
pub fn extract_file_sync(
    path: impl AsRef<Path>,
    mime_type: Option<&str>,
    config: &ExtractionConfig,
) -> Result<ExtractionResult> {
    let runtime = global_runtime()?;
    runtime.block_on(extract_file(path, mime_type, config))
}
|
||||
|
||||
/// Blocking counterpart of `extract_bytes`.
///
/// With the `tokio-runtime` feature (this variant), the call blocks the
/// current thread on the shared global Tokio runtime, which is created once
/// and reused across calls. Without the feature (WASM), a truly synchronous
/// implementation is compiled instead.
///
/// # Errors
///
/// Returns an error if the global runtime cannot be created, or whatever
/// error `extract_bytes` reports.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::extract_bytes_sync;
/// use kreuzberg::core::config::ExtractionConfig;
///
/// let config = ExtractionConfig::default();
/// let bytes = b"Hello, world!";
/// let result = extract_bytes_sync(bytes, "text/plain", &config)?;
/// println!("Content: {}", result.content);
/// # Ok::<(), kreuzberg::KreuzbergError>(())
/// ```
#[cfg(feature = "tokio-runtime")]
pub fn extract_bytes_sync(content: &[u8], mime_type: &str, config: &ExtractionConfig) -> Result<ExtractionResult> {
    let runtime = global_runtime()?;
    runtime.block_on(extract_bytes(content, mime_type, config))
}
|
||||
|
||||
/// Synchronous wrapper for `extract_bytes` (WASM-compatible version).
|
||||
///
|
||||
/// This is a truly synchronous implementation without tokio runtime dependency.
|
||||
/// It calls `extract_bytes_sync_impl()` to perform the extraction.
|
||||
#[cfg(not(feature = "tokio-runtime"))]
|
||||
pub fn extract_bytes_sync(content: &[u8], mime_type: &str, config: &ExtractionConfig) -> Result<ExtractionResult> {
|
||||
super::legacy::extract_bytes_sync_impl(content, Some(mime_type), Some(config))
|
||||
}
|
||||
|
||||
/// Blocking counterpart of `batch_extract_files`.
///
/// Blocks the current thread on the shared global Tokio runtime until the
/// whole batch has been processed. Only available with `tokio-runtime`
/// (WASM has no filesystem).
///
/// # Errors
///
/// Returns an error if the global runtime cannot be created, or whatever
/// error `batch_extract_files` reports.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_files_sync;
/// use kreuzberg::core::config::{ExtractionConfig, BatchFileItem, FileExtractionConfig};
///
/// let config = ExtractionConfig::default();
/// let items = vec![
///     BatchFileItem {
///         path: "doc1.pdf".into(),
///         config: Some(FileExtractionConfig { force_ocr: Some(true), ..Default::default() }),
///     },
///     BatchFileItem { path: "doc2.pdf".into(), config: None },
/// ];
/// let results = batch_extract_files_sync(items, &config)?;
/// # Ok::<(), kreuzberg::KreuzbergError>(())
/// ```
#[cfg(feature = "tokio-runtime")]
pub fn batch_extract_files_sync(items: Vec<BatchFileItem>, config: &ExtractionConfig) -> Result<Vec<ExtractionResult>> {
    let runtime = global_runtime()?;
    runtime.block_on(batch_extract_files(items, config))
}
|
||||
|
||||
/// Blocking counterpart of `batch_extract_bytes`.
///
/// With the `tokio-runtime` feature (this variant), the call blocks the
/// current thread on the shared global Tokio runtime. Without the feature
/// (WASM), a truly synchronous implementation is compiled instead that
/// iterates through the items and calls `extract_bytes_sync()` for each.
///
/// # Errors
///
/// Returns an error if the global runtime cannot be created, or whatever
/// error `batch_extract_bytes` reports.
///
/// # Example
///
/// ```rust,no_run
/// use kreuzberg::core::extractor::batch_extract_bytes_sync;
/// use kreuzberg::core::config::{ExtractionConfig, BatchBytesItem, FileExtractionConfig};
///
/// let config = ExtractionConfig::default();
/// let items = vec![
///     BatchBytesItem { content: b"content".to_vec(), mime_type: "text/plain".to_string(), config: None },
///     BatchBytesItem {
///         content: b"other".to_vec(),
///         mime_type: "text/plain".to_string(),
///         config: Some(FileExtractionConfig { force_ocr: Some(true), ..Default::default() }),
///     },
/// ];
/// let results = batch_extract_bytes_sync(items, &config)?;
/// # Ok::<(), kreuzberg::KreuzbergError>(())
/// ```
#[cfg(feature = "tokio-runtime")]
pub fn batch_extract_bytes_sync(
    items: Vec<BatchBytesItem>,
    config: &ExtractionConfig,
) -> Result<Vec<ExtractionResult>> {
    let runtime = global_runtime()?;
    runtime.block_on(batch_extract_bytes(items, config))
}
|
||||
|
||||
/// Synchronous wrapper for `batch_extract_bytes` (WASM-compatible version).
|
||||
///
|
||||
/// Iterates through items sequentially, applying per-file config overrides.
|
||||
#[cfg(not(feature = "tokio-runtime"))]
|
||||
pub fn batch_extract_bytes_sync(
|
||||
items: Vec<BatchBytesItem>,
|
||||
config: &ExtractionConfig,
|
||||
) -> Result<Vec<ExtractionResult>> {
|
||||
let mut results = Vec::with_capacity(items.len());
|
||||
for item in items {
|
||||
let resolved = match &item.config {
|
||||
Some(fc) => config.with_file_overrides(fc),
|
||||
None => config.clone(),
|
||||
};
|
||||
let result = extract_bytes_sync(&item.content, &item.mime_type, &resolved);
|
||||
results.push(result.unwrap_or_else(|e| error_extraction_result(&e, None)));
|
||||
}
|
||||
Ok(results)
|
||||
}
|
||||
237
crates/kreuzberg/src/core/formats.rs
Normal file
237
crates/kreuzberg/src/core/formats.rs
Normal file
@@ -0,0 +1,237 @@
|
||||
//! Format field validation and metadata.
|
||||
//!
|
||||
//! This module provides a compile-time registry of known format fields used across
|
||||
//! document extractors. It serves as the single source of truth for format validation
|
||||
//! across all language bindings (Rust, Python, TypeScript, Ruby, Java, Go).
|
||||
//!
|
||||
//! # Known Format Fields
|
||||
//!
|
||||
//! The registry contains 55 standardized format fields organized by category:
|
||||
//! - **Document Properties**: title, author, keywords, creator, producer, etc.
|
||||
//! - **Dates**: creation_date, modification_date
|
||||
//! - **Pagination**: page_count
|
||||
//! - **Email Metadata**: from_email, from_name, to_emails, cc_emails, bcc_emails, message_id
|
||||
//! - **Attachments**: attachments
|
||||
//! - **Descriptions**: description, summary
|
||||
//! - **Typography**: fonts
|
||||
//! - **Archive/Compression**: format, file_count, file_list, total_size, compressed_size
|
||||
//! - **Images**: width, height
|
||||
//! - **Content Metrics**: element_count, unique_elements, line_count, word_count, character_count
|
||||
//! - **HTML Structure**: headers, links, code_blocks
|
||||
//! - **Meta Tags**: canonical, base_href, og_*, twitter_*, link_*
|
||||
//! - **OCR**: psm, output_format
|
||||
//! - **Tables**: table_count, table_rows, table_cols
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```ignore
|
||||
//! use kreuzberg::core::formats::{KNOWN_FORMATS, is_valid_format_field};
|
||||
//!
|
||||
//! assert!(is_valid_format_field("title"));
|
||||
//! assert!(!is_valid_format_field("invalid_field"));
|
||||
//! assert_eq!(KNOWN_FORMATS.len(), 55);
|
||||
//! ```
|
||||
|
||||
#[cfg(test)]
|
||||
use ahash::AHashSet;
|
||||
#[cfg(test)]
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Registry of every standardized format field name, in a fixed order.
///
/// The list holds 55 unique, lowercase, snake_case names. Entries are grouped by
/// category below; the grouping is purely cosmetic and does not affect lookups.
/// Compiled only in test builds.
#[cfg(test)]
pub(crate) const KNOWN_FORMATS: &[&str] = &[
    // Document properties
    "title",
    "author",
    "keywords",
    "creator",
    "producer",
    // Dates
    "creation_date",
    "modification_date",
    // Pagination
    "page_count",
    // Email metadata
    "from_email",
    "from_name",
    "to_emails",
    "cc_emails",
    "bcc_emails",
    "message_id",
    // Attachments
    "attachments",
    // Descriptions
    "description",
    "summary",
    // Typography
    "fonts",
    // Archive / compression
    "format",
    "file_count",
    "file_list",
    "total_size",
    "compressed_size",
    // Images
    "width",
    "height",
    // Content metrics
    "element_count",
    "unique_elements",
    "line_count",
    "word_count",
    "character_count",
    // HTML structure
    "headers",
    "links",
    "code_blocks",
    // Meta tags
    "canonical",
    "base_href",
    "og_title",
    "og_description",
    "og_image",
    "og_url",
    "og_type",
    "og_site_name",
    "twitter_card",
    "twitter_title",
    "twitter_description",
    "twitter_image",
    "twitter_site",
    "twitter_creator",
    "link_author",
    "link_license",
    "link_alternate",
    // OCR
    "psm",
    "output_format",
    // Tables
    "table_count",
    "table_rows",
    "table_cols",
];
|
||||
|
||||
/// Hash-set view of `KNOWN_FORMATS` used for membership checks.
///
/// Built on first access by copying every entry of `KNOWN_FORMATS` into an
/// `AHashSet`; later accesses reuse the same set. Compiled only in test builds.
#[cfg(test)]
static FORMAT_FIELD_SET: LazyLock<AHashSet<&'static str>> = LazyLock::new(|| {
    let mut fields = AHashSet::new();
    fields.extend(KNOWN_FORMATS.iter().copied());
    fields
});
|
||||
|
||||
/// Reports whether `field` is one of the names in the known formats registry.
///
/// The check is an exact, case-sensitive match against the pre-built
/// `FORMAT_FIELD_SET` hash set, so no linear scan of `KNOWN_FORMATS` is needed.
/// No trimming or case folding is applied to `field`.
///
/// # Arguments
///
/// * `field` - The field name to validate
///
/// # Returns
///
/// `true` if the field is in KNOWN_FORMATS, `false` otherwise.
///
/// # Example
///
/// The function is crate-private and only compiled for tests, so the example
/// is illustrative rather than a runnable doctest.
///
/// ```ignore
/// use kreuzberg::core::formats::is_valid_format_field;
///
/// assert!(is_valid_format_field("title"));
/// assert!(is_valid_format_field("creation_date"));
/// assert!(!is_valid_format_field("invalid_field"));
/// ```
#[cfg(test)]
#[inline]
pub(crate) fn is_valid_format_field(field: &str) -> bool {
    FORMAT_FIELD_SET.contains(field)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The registry must contain exactly 55 entries.
    #[test]
    fn test_known_formats_count() {
        let total = KNOWN_FORMATS.len();
        assert_eq!(total, 55, "Expected 55 known format fields");
    }

    /// No field name may appear in the registry more than once.
    #[test]
    fn test_known_formats_no_duplicates() {
        let mut seen = std::collections::HashSet::new();
        KNOWN_FORMATS.iter().for_each(|field| {
            assert!(seen.insert(field), "Duplicate format field found: {}", field);
        });
    }

    /// A sample of registered names is accepted.
    #[test]
    fn test_is_valid_format_field_true_cases() {
        let accepted = [
            "title",
            "author",
            "creation_date",
            "page_count",
            "from_email",
            "og_title",
            "twitter_card",
        ];
        for name in accepted {
            assert!(is_valid_format_field(name));
        }
    }

    /// Unknown, empty, differently-cased and padded names are rejected.
    #[test]
    fn test_is_valid_format_field_false_cases() {
        let rejected = ["invalid_field", "unknown_metadata", "", "TITLE", "title "];
        for name in rejected {
            assert!(!is_valid_format_field(name));
        }
    }

    /// Every document-property field is registered.
    #[test]
    fn test_all_document_property_fields() {
        let doc_fields = ["title", "author", "keywords", "creator", "producer"];
        doc_fields.iter().for_each(|field| {
            assert!(is_valid_format_field(field), "Missing field: {}", field);
        });
    }

    /// Every email-related field is registered.
    #[test]
    fn test_all_email_fields() {
        let email_fields = [
            "from_email",
            "from_name",
            "to_emails",
            "cc_emails",
            "bcc_emails",
            "message_id",
            "attachments",
        ];
        email_fields.iter().for_each(|field| {
            assert!(is_valid_format_field(field), "Missing email field: {}", field);
        });
    }

    /// Every web meta-tag field is registered.
    #[test]
    fn test_all_web_meta_fields() {
        let web_fields = [
            "og_title",
            "og_description",
            "og_image",
            "og_url",
            "og_type",
            "og_site_name",
            "twitter_card",
            "twitter_title",
            "twitter_description",
            "twitter_image",
            "twitter_site",
            "twitter_creator",
            "canonical",
            "base_href",
        ];
        web_fields.iter().for_each(|field| {
            assert!(is_valid_format_field(field), "Missing web field: {}", field);
        });
    }

    /// Every table field is registered.
    #[test]
    fn test_all_table_fields() {
        let table_fields = ["table_count", "table_rows", "table_cols"];
        table_fields.iter().for_each(|field| {
            assert!(is_valid_format_field(field), "Missing table field: {}", field);
        });
    }
}
|
||||
421
crates/kreuzberg/src/core/io.rs
Normal file
421
crates/kreuzberg/src/core/io.rs
Normal file
@@ -0,0 +1,421 @@
|
||||
//! File I/O utilities.
|
||||
//!
|
||||
//! This module provides async and sync file reading utilities with proper error handling.
|
||||
//! For large files (> 1 MiB) on non-WASM platforms, memory-mapped I/O is used to avoid
|
||||
//! heap-allocating the entire file contents, reducing memory pressure and syscall overhead.
|
||||
|
||||
use crate::{KreuzbergError, Result};
|
||||
use std::path::Path;
|
||||
|
||||
/// File size, in bytes, above which a file is memory-mapped instead of read.
///
/// Files at or below this size are loaded with a regular `read()`: for small
/// files the fixed cost of setting up a mapping (open, fstat, mmap syscalls plus
/// TLB pressure) outweighs what the mapping saves.
#[cfg(not(target_arch = "wasm32"))]
const MMAP_THRESHOLD_BYTES: u64 = 1024 * 1024; // 1 MiB
|
||||
|
||||
/// An owned buffer of file bytes.
|
||||
///
|
||||
/// On non-WASM platforms this may be backed by a memory-mapped file (zero heap
|
||||
/// allocation for the file contents) or by a `Vec<u8>` for small files.
|
||||
/// On WASM it is always a `Vec<u8>`.
|
||||
///
|
||||
/// Implements `Deref<Target = [u8]>` so callers can pass `&FileBytes` as `&[u8]`
|
||||
/// without any additional copy.
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub struct FileBytes {
|
||||
inner: FileBytesInner,
|
||||
}
|
||||
|
||||
enum FileBytesInner {
|
||||
/// Regular heap-allocated buffer (small files or WASM).
|
||||
Heap(Vec<u8>),
|
||||
/// Memory-mapped file (large files on native platforms).
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
Mapped(memmap2::Mmap),
|
||||
}
|
||||
|
||||
impl std::ops::Deref for FileBytes {
|
||||
type Target = [u8];
|
||||
|
||||
fn deref(&self) -> &[u8] {
|
||||
match &self.inner {
|
||||
FileBytesInner::Heap(v) => v.as_slice(),
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
FileBytesInner::Mapped(m) => m.as_ref(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl AsRef<[u8]> for FileBytes {
|
||||
fn as_ref(&self) -> &[u8] {
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Open a file and return its bytes with zero-copy for large files.
|
||||
///
|
||||
/// On non-WASM targets, files larger than [`MMAP_THRESHOLD_BYTES`] are
|
||||
/// memory-mapped so that the file contents are never copied to the heap.
|
||||
/// The mapping is read-only; the file must not be modified while the returned
|
||||
/// [`FileBytes`] is alive, which is safe for document extraction.
|
||||
///
|
||||
/// On WASM or for small files, falls back to a plain `std::fs::read`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Io` for any I/O failure.
|
||||
#[allow(unsafe_code)]
|
||||
pub(crate) fn open_file_bytes(path: &Path) -> Result<FileBytes> {
|
||||
#[cfg(not(target_arch = "wasm32"))]
|
||||
{
|
||||
let metadata = std::fs::metadata(path).map_err(KreuzbergError::Io)?;
|
||||
if metadata.len() > MMAP_THRESHOLD_BYTES {
|
||||
let file = std::fs::File::open(path).map_err(KreuzbergError::Io)?;
|
||||
// SAFETY: The file is opened read-only and we do not write to the
|
||||
// mapped region. The `FileBytes` value owns the `Mmap` handle and
|
||||
// the mapping is live for exactly as long as the bytes are accessed.
|
||||
// External modification of the file while mapped is a documented
|
||||
// TOCTOU risk inherent to mmap on all platforms; it is acceptable
|
||||
// here because kreuzberg only reads user-supplied documents and
|
||||
// makes no correctness guarantees about files modified concurrently.
|
||||
let mmap = unsafe { memmap2::Mmap::map(&file) }.map_err(KreuzbergError::Io)?;
|
||||
return Ok(FileBytes {
|
||||
inner: FileBytesInner::Mapped(mmap),
|
||||
});
|
||||
}
|
||||
}
|
||||
// Small file or WASM: regular heap read.
|
||||
let bytes = std::fs::read(path).map_err(KreuzbergError::Io)?;
|
||||
Ok(FileBytes {
|
||||
inner: FileBytesInner::Heap(bytes),
|
||||
})
|
||||
}
|
||||
|
||||
/// Read the whole file at `path` without blocking the async executor.
///
/// Delegates to `tokio::fs::read`.
///
/// # Arguments
///
/// * `path` - Path to the file to read
///
/// # Returns
///
/// The complete file contents as a byte vector.
///
/// # Errors
///
/// Any I/O failure is returned wrapped in `KreuzbergError::Io`.
#[cfg(feature = "tokio-runtime")]
pub(crate) async fn read_file_async(path: impl AsRef<Path>) -> Result<Vec<u8>> {
    match tokio::fs::read(path.as_ref()).await {
        Ok(contents) => Ok(contents),
        Err(io_err) => Err(KreuzbergError::Io(io_err)),
    }
}
|
||||
|
||||
/// Read the whole file at `path`, blocking until done (test builds only).
///
/// Delegates to `std::fs::read`.
///
/// # Arguments
///
/// * `path` - Path to the file to read
///
/// # Returns
///
/// The complete file contents as a byte vector.
///
/// # Errors
///
/// Any I/O failure is returned wrapped in `KreuzbergError::Io`.
#[cfg(test)]
pub(crate) fn read_file_sync(path: impl AsRef<Path>) -> Result<Vec<u8>> {
    match std::fs::read(path.as_ref()) {
        Ok(contents) => Ok(contents),
        Err(io_err) => Err(KreuzbergError::Io(io_err)),
    }
}
|
||||
|
||||
/// Report whether something exists at `path`.
///
/// The check succeeds when filesystem metadata for the path can be read, so it
/// is `true` for directories as well as regular files, and `false` when the
/// metadata query fails for any reason (including a missing path).
///
/// # Arguments
///
/// * `path` - Path to check
///
/// # Returns
///
/// `true` if the path exists, `false` otherwise.
pub(crate) fn file_exists(path: impl AsRef<Path>) -> bool {
    std::fs::metadata(path.as_ref()).is_ok()
}
|
||||
|
||||
/// Validate that a file exists.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to validate
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Io` if file doesn't exist.
|
||||
pub(crate) fn validate_file_exists(path: impl AsRef<Path>) -> Result<()> {
|
||||
if !file_exists(&path) {
|
||||
return Err(KreuzbergError::from(std::io::Error::new(
|
||||
std::io::ErrorKind::NotFound,
|
||||
format!("File does not exist: {}", path.as_ref().display()),
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Collect the paths of files under `dir`, optionally filtered (test builds only).
///
/// # Arguments
///
/// * `dir` - Directory to traverse
/// * `recursive` - When `true`, subdirectories are descended into as well
/// * `filter` - Optional predicate; when present, only files for which it
///   returns `true` are kept. `None` keeps every file.
///
/// # Returns
///
/// The matching file paths, in the order the directory listing yields them.
///
/// # Errors
///
/// Returns a `KreuzbergError` converted from an `std::io::Error` of kind
/// `NotADirectory` when `dir` is not a directory, and `KreuzbergError::Io` for
/// failures while listing directories.
#[cfg(test)]
pub(crate) fn traverse_directory<F>(
    dir: impl AsRef<Path>,
    recursive: bool,
    filter: Option<F>,
) -> Result<Vec<std::path::PathBuf>>
where
    F: Fn(&Path) -> bool,
{
    let root = dir.as_ref();

    if root.is_dir() {
        let mut found = Vec::new();
        traverse_directory_impl(root, recursive, &filter, &mut found)?;
        return Ok(found);
    }

    Err(KreuzbergError::from(std::io::Error::new(
        std::io::ErrorKind::NotADirectory,
        format!("Path is not a directory: {}", root.display()),
    )))
}
|
||||
|
||||
/// Recursive worker behind `traverse_directory`.
///
/// Lists `dir` and appends every accepted file to `files`. An entry that is a
/// file is accepted when `filter` is `None` or the predicate returns `true`.
/// An entry that is a directory is descended into only when `recursive` is set.
/// Entries that are neither are ignored.
///
/// # Errors
///
/// Returns `KreuzbergError::Io` if a directory cannot be listed or an entry
/// cannot be read; the traversal stops at the first such error.
#[cfg(test)]
fn traverse_directory_impl<F>(
    dir: &Path,
    recursive: bool,
    filter: &Option<F>,
    files: &mut Vec<std::path::PathBuf>,
) -> Result<()>
where
    F: Fn(&Path) -> bool,
{
    for entry in std::fs::read_dir(dir).map_err(KreuzbergError::Io)? {
        let candidate = entry.map_err(KreuzbergError::Io)?.path();

        if candidate.is_file() {
            let accepted = filter.as_ref().map_or(true, |keep| keep(&candidate));
            if accepted {
                files.push(candidate);
            }
            continue;
        }

        if candidate.is_dir() && recursive {
            traverse_directory_impl(&candidate, recursive, filter, files)?;
        }
    }

    Ok(())
}
|
||||
|
||||
/// List the files under `dir` whose extension equals `extension` (test builds only).
///
/// The comparison ignores case: both the requested extension and each file's
/// extension are lowercased before being compared. Files without an extension,
/// or whose extension is not valid UTF-8, never match.
///
/// # Arguments
///
/// * `dir` - Directory to search
/// * `extension` - File extension to match (without the dot)
/// * `recursive` - Whether to recursively search subdirectories
///
/// # Returns
///
/// Vector of file paths with the specified extension.
///
/// # Errors
///
/// Returns whatever error `traverse_directory` reports.
#[cfg(test)]
pub(crate) fn find_files_by_extension(
    dir: impl AsRef<Path>,
    extension: &str,
    recursive: bool,
) -> Result<Vec<std::path::PathBuf>> {
    let wanted = extension.to_lowercase();
    let has_wanted_extension = |candidate: &Path| {
        candidate
            .extension()
            .and_then(|raw| raw.to_str())
            .is_some_and(|text| text.to_lowercase() == wanted)
    };

    traverse_directory(dir, recursive, Some(has_wanted_extension))
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    /// Creates an empty file at `path`, panicking on failure.
    fn touch(path: impl AsRef<Path>) {
        File::create(path).unwrap();
    }

    /// Creates `path` and fills it with `contents`, panicking on failure.
    fn write_fixture(path: impl AsRef<Path>, contents: &[u8]) {
        File::create(path).unwrap().write_all(contents).unwrap();
    }

    #[cfg(feature = "tokio-runtime")]
    #[tokio::test]
    async fn test_read_file_async() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("test.txt");
        write_fixture(&target, b"test content");

        let content = read_file_async(&target).await.unwrap();
        assert_eq!(content, b"test content");
    }

    #[test]
    fn test_read_file_sync() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("test.txt");
        write_fixture(&target, b"test content");

        let content = read_file_sync(&target).unwrap();
        assert_eq!(content, b"test content");
    }

    #[test]
    fn test_file_exists() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("test.txt");
        touch(&target);

        assert!(file_exists(&target));
        assert!(!file_exists(scratch.path().join("nonexistent.txt")));
    }

    #[test]
    fn test_validate_file_exists() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("test.txt");
        touch(&target);

        assert!(validate_file_exists(&target).is_ok());
        assert!(validate_file_exists(scratch.path().join("nonexistent.txt")).is_err());
    }

    #[test]
    fn test_traverse_directory_non_recursive() {
        let scratch = tempdir().unwrap();
        let root = scratch.path();

        for name in ["file1.txt", "file2.pdf", "file3.txt"] {
            touch(root.join(name));
        }

        // A nested file that a non-recursive walk must not report.
        std::fs::create_dir(root.join("subdir")).unwrap();
        touch(root.join("subdir").join("file4.txt"));

        let files = traverse_directory(root, false, None::<fn(&Path) -> bool>).unwrap();
        assert_eq!(files.len(), 3);
    }

    #[test]
    fn test_traverse_directory_recursive() {
        let scratch = tempdir().unwrap();
        let root = scratch.path();

        for name in ["file1.txt", "file2.pdf"] {
            touch(root.join(name));
        }

        std::fs::create_dir(root.join("subdir")).unwrap();
        for name in ["file3.txt", "file4.pdf"] {
            touch(root.join("subdir").join(name));
        }

        let files = traverse_directory(root, true, None::<fn(&Path) -> bool>).unwrap();
        assert_eq!(files.len(), 4);
    }

    #[test]
    fn test_traverse_directory_with_filter() {
        let scratch = tempdir().unwrap();
        let root = scratch.path();

        for name in ["file1.txt", "file2.pdf", "file3.txt"] {
            touch(root.join(name));
        }

        let only_txt = |path: &Path| path.extension().and_then(|e| e.to_str()).is_some_and(|e| e == "txt");
        let files = traverse_directory(root, false, Some(only_txt)).unwrap();

        assert_eq!(files.len(), 2);
        assert!(files.iter().all(|p| p.extension().unwrap() == "txt"));
    }

    #[test]
    fn test_find_files_by_extension() {
        let scratch = tempdir().unwrap();
        let root = scratch.path();

        // The upper-case extension exercises the case-insensitive comparison.
        for name in ["file1.txt", "file2.pdf", "file3.TXT"] {
            touch(root.join(name));
        }

        std::fs::create_dir(root.join("subdir")).unwrap();
        touch(root.join("subdir").join("file4.txt"));

        let files = find_files_by_extension(root, "txt", false).unwrap();
        assert_eq!(files.len(), 2);

        let files_recursive = find_files_by_extension(root, "txt", true).unwrap();
        assert_eq!(files_recursive.len(), 3);
    }

    #[test]
    fn test_traverse_directory_invalid_path() {
        let result = traverse_directory("/nonexistent/directory", false, None::<fn(&Path) -> bool>);
        assert!(result.is_err());
    }

    #[test]
    fn test_traverse_directory_file_not_dir() {
        let scratch = tempdir().unwrap();
        let target = scratch.path().join("test.txt");
        touch(&target);

        let result = traverse_directory(&target, false, None::<fn(&Path) -> bool>);
        assert!(result.is_err());
    }

    #[cfg(feature = "tokio-runtime")]
    #[tokio::test]
    async fn test_read_file_async_io_error() {
        let result = read_file_async("/nonexistent/file.txt").await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), KreuzbergError::Io(_)));
    }

    #[test]
    fn test_read_file_sync_io_error() {
        let result = read_file_sync("/nonexistent/file.txt");
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), KreuzbergError::Io(_)));
    }
}
|
||||
1266
crates/kreuzberg/src/core/mime.rs
Normal file
1266
crates/kreuzberg/src/core/mime.rs
Normal file
File diff suppressed because it is too large
Load Diff
61
crates/kreuzberg/src/core/mod.rs
Normal file
61
crates/kreuzberg/src/core/mod.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
//! Core extraction orchestration module.
|
||||
//!
|
||||
//! This module contains the main extraction logic and orchestration layer for Kreuzberg.
|
||||
//! It provides the primary entry points for file and bytes extraction, manages the
|
||||
//! extractor registry, MIME type detection, configuration, and post-processing pipeline.
|
||||
//!
|
||||
//! # Architecture
|
||||
//!
|
||||
//! The core module is responsible for:
|
||||
//! - **Entry Points**: Main `extract_file()` and `extract_bytes()` functions
|
||||
//! - **Registry**: Mapping MIME types to extractors with priority-based selection
|
||||
//! - **MIME Detection**: Detecting and validating MIME types from files and extensions
|
||||
//! - **Pipeline**: Orchestrating post-processing steps (chunking, quality, etc.)
|
||||
//! - **Configuration**: Loading and managing extraction configuration
|
||||
//! - **I/O**: File reading and validation utilities
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use kreuzberg::core::extractor::extract_file;
|
||||
//! use kreuzberg::core::config::ExtractionConfig;
|
||||
//!
|
||||
//! # async fn example() -> kreuzberg::Result<()> {
|
||||
//! let config = ExtractionConfig::default();
|
||||
//! let result = extract_file("document.pdf", None, &config).await?;
|
||||
//! println!("Extracted content: {}", result.content);
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
pub mod batch_mode;
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
pub mod batch_optimizations;
|
||||
pub mod config;
|
||||
pub mod config_validation;
|
||||
pub mod extractor;
|
||||
pub mod formats;
|
||||
pub mod io;
|
||||
pub mod mime;
|
||||
pub(crate) mod path_resolver;
|
||||
pub mod pipeline;
|
||||
#[cfg(feature = "api-types")]
|
||||
pub mod server_config;
|
||||
|
||||
#[cfg(feature = "pdf")]
|
||||
pub use config::HierarchyConfig;
|
||||
pub use config::{
|
||||
ChunkingConfig, EmbeddingConfig, EmbeddingModelType, ExtractionConfig, ImageExtractionConfig,
|
||||
LanguageDetectionConfig, OcrConfig, OutputFormat, PageConfig, PostProcessorConfig, TokenReductionOptions,
|
||||
};
|
||||
#[cfg(feature = "api-types")]
|
||||
pub use server_config::ServerConfig;
|
||||
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
pub use batch_optimizations::{BatchProcessor, BatchProcessorConfig};
|
||||
#[cfg(feature = "pdf")]
|
||||
pub use config::PdfConfig;
|
||||
#[cfg(feature = "tokio-runtime")]
|
||||
pub use extractor::{batch_extract_bytes, batch_extract_files};
|
||||
pub use extractor::{extract_bytes, extract_file};
|
||||
282
crates/kreuzberg/src/core/path_resolver.rs
Normal file
282
crates/kreuzberg/src/core/path_resolver.rs
Normal file
@@ -0,0 +1,282 @@
|
||||
//! Image path resolution for markup extractors.
|
||||
//!
|
||||
//! Resolves relative image paths found in markup documents (Markdown, LaTeX, RST,
|
||||
//! Org-mode, Typst, Djot, DocBook) to actual filesystem paths, reads the image data,
|
||||
//! and attaches them to the extraction result.
|
||||
|
||||
use std::borrow::Cow;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use bytes::Bytes;
|
||||
|
||||
use crate::core::config::ExtractionConfig;
|
||||
use crate::types::ExtractedImage;
|
||||
use crate::types::internal::InternalDocument;
|
||||
use crate::types::uri::UriKind;
|
||||
|
||||
/// Largest image file, in bytes, that will be read: 50 MiB (50 × 1024 × 1024).
const MAX_IMAGE_SIZE: u64 = 50 * (1024 * 1024);
|
||||
|
||||
/// Resolve a relative image reference against a base directory.
///
/// Returns `None` for URLs (`http://`, `https://`, `data:`, `ftp://`,
/// `mailto:`), for absolute paths (leading `/` or a Windows drive letter), and
/// for references whose lexically resolved location lies outside `base_dir`.
/// Returns `Some(resolved)` for safe relative paths, where `resolved` is
/// `base_dir` joined with the reference and with `.` / `..` resolved.
///
/// A leading `file://` or `file:` prefix (org-mode uses `file:` without `//`)
/// is stripped before the reference is interpreted.
///
/// This function performs no filesystem access — it only validates the
/// structural safety of the path. Symlinks are therefore not taken into account.
pub(crate) fn resolve_image_path(base_dir: &Path, image_ref: &str) -> Option<PathBuf> {
    /// Resolves `.` and `..` lexically. A `..` that has no preceding normal
    /// component to cancel is kept (relative paths) or ignored (directly under
    /// the root), so the result always denotes the same location as the input.
    fn lexical_normalize(path: &Path) -> PathBuf {
        use std::path::Component;

        let mut out = PathBuf::new();
        for component in path.components() {
            match component {
                Component::CurDir => {}
                Component::ParentDir => match out.components().next_back() {
                    Some(Component::Normal(_)) => {
                        out.pop();
                    }
                    // The parent of the root is the root itself.
                    Some(Component::RootDir | Component::Prefix(_)) => {}
                    // Nothing to cancel: the `..` must be preserved.
                    _ => out.push(".."),
                },
                other => out.push(other.as_os_str()),
            }
        }
        out
    }

    const REMOTE_PREFIXES: [&str; 5] = ["http://", "https://", "data:", "ftp://", "mailto:"];

    let trimmed = image_ref.trim();

    // Reject URLs
    if REMOTE_PREFIXES.iter().any(|prefix| trimmed.starts_with(prefix)) {
        return None;
    }

    // Strip file:// or file: prefix (org-mode uses file: without //)
    let path_str = trimmed
        .strip_prefix("file://")
        .or_else(|| trimmed.strip_prefix("file:"))
        .unwrap_or(trimmed);

    // Reject absolute paths (Unix or Windows drive letter)
    let raw = path_str.as_bytes();
    let has_drive_letter = raw.len() >= 2 && raw[0].is_ascii_alphabetic() && raw[1] == b':';
    if path_str.starts_with('/') || has_drive_letter {
        return None;
    }

    let norm_base = lexical_normalize(base_dir);
    let normalized = lexical_normalize(&base_dir.join(path_str));

    // Path traversal prevention: the resolved path must lie under the base
    // directory. `strip_prefix` alone is not enough when the base is empty or
    // itself starts with `..`, so the remainder must also be free of `..`.
    let below_base = normalized.strip_prefix(&norm_base).ok()?;
    let escapes = below_base
        .components()
        .any(|component| matches!(component, std::path::Component::ParentDir));
    if escapes {
        return None;
    }

    Some(normalized)
}
|
||||
|
||||
/// Read an image file and produce an `ExtractedImage`.
|
||||
///
|
||||
/// Returns `None` if the file does not exist, is not a regular file,
|
||||
/// exceeds the size limit, or has an unrecognised extension.
|
||||
pub(crate) fn read_image_file(path: &Path, image_index: u32) -> Option<ExtractedImage> {
|
||||
let meta = std::fs::metadata(path).ok()?;
|
||||
if !meta.is_file() {
|
||||
return None;
|
||||
}
|
||||
if meta.len() > MAX_IMAGE_SIZE {
|
||||
return None;
|
||||
}
|
||||
|
||||
let ext = path
|
||||
.extension()
|
||||
.and_then(|e| e.to_str())
|
||||
.map(|s| s.to_ascii_lowercase())?;
|
||||
|
||||
let format: Cow<'static, str> = match ext.as_str() {
|
||||
"png" => Cow::Borrowed("png"),
|
||||
"jpg" | "jpeg" => Cow::Borrowed("jpeg"),
|
||||
"gif" => Cow::Borrowed("gif"),
|
||||
"webp" => Cow::Borrowed("webp"),
|
||||
"svg" => Cow::Borrowed("svg"),
|
||||
"bmp" => Cow::Borrowed("bmp"),
|
||||
"tiff" | "tif" => Cow::Borrowed("tiff"),
|
||||
"avif" => Cow::Borrowed("avif"),
|
||||
_ => return None,
|
||||
};
|
||||
|
||||
let data = std::fs::read(path).ok()?;
|
||||
let source_path = path.to_string_lossy().into_owned();
|
||||
|
||||
Some(ExtractedImage {
|
||||
data: Bytes::from(data),
|
||||
format,
|
||||
image_index,
|
||||
page_number: None,
|
||||
width: None,
|
||||
height: None,
|
||||
colorspace: None,
|
||||
bits_per_component: None,
|
||||
is_mask: false,
|
||||
description: None,
|
||||
ocr_result: None,
|
||||
bounding_box: None,
|
||||
source_path: Some(source_path),
|
||||
image_kind: None,
|
||||
kind_confidence: None,
|
||||
cluster_id: None,
|
||||
})
|
||||
}
|
||||
|
||||
/// Resolve image URIs in an `InternalDocument` to actual image data.
|
||||
///
|
||||
/// Iterates over all `UriKind::Image` entries, resolves them relative to
|
||||
/// `base_dir`, reads the file, and appends the result to `doc.images`.
|
||||
/// No-op when image extraction is disabled in `config`.
|
||||
pub(crate) fn resolve_image_uris(doc: &mut InternalDocument, base_dir: &Path, config: &ExtractionConfig) {
|
||||
let image_extraction_enabled = config.images.as_ref().is_some_and(|img| img.extract_images);
|
||||
|
||||
if !image_extraction_enabled {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut image_index = doc.images.len() as u32;
|
||||
|
||||
// Collect URI indices first to avoid borrow conflict (doc.uris vs doc.images).
|
||||
let image_uri_indices: Vec<usize> = doc
|
||||
.uris
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter(|(_, uri)| uri.kind == UriKind::Image)
|
||||
.map(|(i, _)| i)
|
||||
.collect();
|
||||
|
||||
for idx in image_uri_indices {
|
||||
if let Some(resolved) = resolve_image_path(base_dir, &doc.uris[idx].url)
|
||||
&& let Some(img) = read_image_file(&resolved, image_index)
|
||||
{
|
||||
doc.images.push(img);
|
||||
image_index += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Read a file, extract via `extract_bytes`, then resolve image URIs.
|
||||
///
|
||||
/// Shared helper for markup extractors (Markdown, LaTeX, RST, Org-mode, Typst,
|
||||
/// Djot, DocBook, MDX) that need to resolve relative image paths after extraction.
|
||||
pub(crate) async fn extract_file_with_image_resolution(
|
||||
extractor: &(dyn crate::plugins::DocumentExtractor + Sync),
|
||||
path: &Path,
|
||||
mime_type: &str,
|
||||
config: &ExtractionConfig,
|
||||
) -> crate::Result<InternalDocument> {
|
||||
let bytes = crate::core::io::open_file_bytes(path)?;
|
||||
let mut doc = extractor.extract_bytes(&bytes, mime_type, config).await?;
|
||||
if let Some(base_dir) = path.parent() {
|
||||
resolve_image_uris(&mut doc, base_dir, config);
|
||||
}
|
||||
Ok(doc)
|
||||
}
|
||||
|
||||
/// Normalize a path by resolving `.` and `..` components without filesystem access.
///
/// The normalization is purely lexical (symlinks are not followed):
///
/// * `.` components are dropped.
/// * `..` removes the preceding *normal* component (`a/b/..` becomes `a`).
/// * `..` directly under a root or prefix is ignored, because the parent of the
///   root is the root itself (`/../etc` becomes `/etc`, never the relative `etc`).
/// * `..` with nothing left to remove is preserved (`a/../../b` becomes `../b`),
///   so a path that climbs above its starting point never looks as if it were
///   still inside it.
fn normalize_path(path: &Path) -> PathBuf {
    use std::path::Component;

    let mut components: Vec<Component<'_>> = Vec::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match components.last().copied() {
                // A regular directory name is cancelled out by `..`.
                Some(Component::Normal(_)) => {
                    components.pop();
                }
                // Cannot climb above the root / drive prefix.
                Some(Component::RootDir | Component::Prefix(_)) => {}
                // Nothing to cancel: keep the `..` so the escape stays visible.
                _ => components.push(component),
            },
            other => components.push(other),
        }
    }
    components.iter().collect()
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::Path;

    /// Base directory shared by most of the tests below.
    const DOCS_DIR: &str = "/home/user/docs";

    /// Resolve `url` against the textual base directory `base`.
    fn resolve_in(base: &str, url: &str) -> Option<PathBuf> {
        resolve_image_path(Path::new(base), url)
    }

    #[test]
    fn test_resolve_relative_path() {
        let expected = PathBuf::from("/home/user/docs/images/photo.png");
        assert_eq!(resolve_in(DOCS_DIR, "images/photo.png"), Some(expected));
    }

    #[test]
    fn test_resolve_nested_relative() {
        let expected = PathBuf::from("/project/content/images/subfolder/nested.png");
        assert_eq!(
            resolve_in("/project/content", "images/subfolder/nested.png"),
            Some(expected)
        );
    }

    #[test]
    fn test_reject_absolute_path() {
        assert_eq!(resolve_in(DOCS_DIR, "/etc/passwd"), None);
    }

    #[test]
    fn test_reject_traversal() {
        assert_eq!(resolve_in(DOCS_DIR, "../../etc/passwd"), None);
    }

    #[test]
    fn test_reject_http_url() {
        assert_eq!(resolve_in(DOCS_DIR, "https://example.com/img.png"), None);
    }

    #[test]
    fn test_reject_data_uri() {
        assert_eq!(resolve_in(DOCS_DIR, "data:image/png;base64,abc"), None);
    }

    #[test]
    fn test_nonexistent_file_still_resolves() {
        // Resolution is structural only: the file does not have to exist.
        let expected = PathBuf::from("/nonexistent/base/sub/image.jpg");
        assert_eq!(resolve_in("/nonexistent/base", "sub/image.jpg"), Some(expected));
    }

    #[test]
    fn test_path_with_spaces() {
        let expected = PathBuf::from("/home/user/my docs/my images/photo.png");
        assert_eq!(
            resolve_in("/home/user/my docs", "my images/photo.png"),
            Some(expected)
        );
    }

    #[test]
    fn test_windows_backslash() {
        // NOTE: despite its name, this test feeds a forward-slash path; it only
        // checks that resolution succeeds without panicking. On Unix a backslash
        // would be an ordinary filename character.
        assert!(resolve_in(DOCS_DIR, "images/photo.png").is_some());
    }

    #[test]
    fn test_reject_ftp_url() {
        assert_eq!(resolve_in(DOCS_DIR, "ftp://server/img.png"), None);
    }

    #[test]
    fn test_reject_mailto() {
        assert_eq!(resolve_in(DOCS_DIR, "mailto:user@example.com"), None);
    }

    #[test]
    fn test_file_uri_stripped() {
        let expected = PathBuf::from("/home/user/docs/images/photo.png");
        assert_eq!(resolve_in(DOCS_DIR, "file://images/photo.png"), Some(expected));
    }

    #[test]
    fn test_reject_windows_absolute() {
        assert_eq!(resolve_in(DOCS_DIR, "C:\\Windows\\img.png"), None);
    }
}
|
||||
46
crates/kreuzberg/src/core/pipeline/cache.rs
Normal file
46
crates/kreuzberg/src/core/pipeline/cache.rs
Normal file
@@ -0,0 +1,46 @@
|
||||
//! Processor caching to reduce lock contention.
|
||||
//!
|
||||
//! This module manages the caching of post-processors by processing stage,
|
||||
//! eliminating repeated registry lock acquisitions.
|
||||
|
||||
use crate::Result;
|
||||
use crate::plugins::{PostProcessor, ProcessingStage};
|
||||
use parking_lot::RwLock;
|
||||
use std::sync::Arc;
|
||||
use std::sync::LazyLock;
|
||||
|
||||
/// Cached post-processors for each stage to reduce lock contention.
|
||||
///
|
||||
/// This cache is populated once during the first pipeline run and reused
|
||||
/// for all subsequent extractions, eliminating 3 of 4 registry lock acquisitions
|
||||
/// per extraction.
|
||||
pub(super) struct ProcessorCache {
|
||||
pub(super) early: Arc<Vec<Arc<dyn PostProcessor>>>,
|
||||
pub(super) middle: Arc<Vec<Arc<dyn PostProcessor>>>,
|
||||
pub(super) late: Arc<Vec<Arc<dyn PostProcessor>>>,
|
||||
}
|
||||
|
||||
impl ProcessorCache {
|
||||
/// Create a new processor cache by fetching from the registry.
|
||||
pub(super) fn new() -> Result<Self> {
|
||||
let processor_registry = crate::plugins::registry::get_post_processor_registry();
|
||||
let registry = processor_registry.read();
|
||||
|
||||
Ok(Self {
|
||||
early: Arc::new(registry.get_for_stage(ProcessingStage::Early)),
|
||||
middle: Arc::new(registry.get_for_stage(ProcessingStage::Middle)),
|
||||
late: Arc::new(registry.get_for_stage(ProcessingStage::Late)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Lazy processor cache - initialized on first use, then cached.
|
||||
pub(super) static PROCESSOR_CACHE: LazyLock<RwLock<Option<ProcessorCache>>> = LazyLock::new(|| RwLock::new(None));
|
||||
|
||||
/// Clear the processor cache (primarily for testing when registry changes).
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn clear_processor_cache() -> Result<()> {
|
||||
let mut cache = PROCESSOR_CACHE.write();
|
||||
*cache = None;
|
||||
Ok(())
|
||||
}
|
||||
128
crates/kreuzberg/src/core/pipeline/execution.rs
Normal file
128
crates/kreuzberg/src/core/pipeline/execution.rs
Normal file
@@ -0,0 +1,128 @@
|
||||
//! Core processor execution logic.
|
||||
//!
|
||||
//! This module handles the execution of post-processors and validators
|
||||
//! in the correct order.
|
||||
|
||||
use crate::core::config::ExtractionConfig;
|
||||
use crate::plugins::ProcessingStage;
|
||||
use crate::types::{ExtractionResult, ProcessingWarning};
|
||||
use crate::{KreuzbergError, Result};
|
||||
use std::borrow::Cow;
|
||||
#[cfg(feature = "otel")]
|
||||
use std::time::Instant;
|
||||
#[cfg(feature = "otel")]
|
||||
use tracing::Instrument;
|
||||
|
||||
/// Execute all registered post-processors by stage.
|
||||
pub(super) async fn execute_processors(
|
||||
result: &mut ExtractionResult,
|
||||
config: &ExtractionConfig,
|
||||
pp_config: &Option<&crate::core::config::PostProcessorConfig>,
|
||||
early_processors: std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
|
||||
middle_processors: std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
|
||||
late_processors: std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
|
||||
) -> Result<()> {
|
||||
for (_stage, processors_arc) in [
|
||||
(ProcessingStage::Early, early_processors),
|
||||
(ProcessingStage::Middle, middle_processors),
|
||||
(ProcessingStage::Late, late_processors),
|
||||
] {
|
||||
#[cfg(feature = "otel")]
|
||||
let stage_name = match _stage {
|
||||
ProcessingStage::Early => crate::telemetry::conventions::stages::POST_PROCESSING_EARLY,
|
||||
ProcessingStage::Middle => crate::telemetry::conventions::stages::POST_PROCESSING_MIDDLE,
|
||||
ProcessingStage::Late => crate::telemetry::conventions::stages::POST_PROCESSING_LATE,
|
||||
};
|
||||
#[cfg(feature = "otel")]
|
||||
let stage_span = crate::telemetry::spans::pipeline_stage_span(stage_name);
|
||||
#[cfg(feature = "otel")]
|
||||
let stage_start = Instant::now();
|
||||
#[cfg(feature = "otel")]
|
||||
let _stage_guard = stage_span.enter();
|
||||
|
||||
for processor in processors_arc.iter() {
|
||||
let processor_name = processor.name();
|
||||
|
||||
let should_run = should_processor_run(pp_config, processor_name);
|
||||
|
||||
if should_run && processor.should_process(result, config) {
|
||||
#[cfg(feature = "otel")]
|
||||
let processor_span = crate::telemetry::spans::pipeline_processor_span(stage_name, processor_name);
|
||||
|
||||
#[cfg(feature = "otel")]
|
||||
let process_result = processor.process(result, config).instrument(processor_span).await;
|
||||
#[cfg(not(feature = "otel"))]
|
||||
let process_result = processor.process(result, config).await;
|
||||
|
||||
match process_result {
|
||||
Ok(_) => {}
|
||||
Err(err @ KreuzbergError::Io(_))
|
||||
| Err(err @ KreuzbergError::LockPoisoned(_))
|
||||
| Err(err @ KreuzbergError::Plugin { .. }) => {
|
||||
return Err(err);
|
||||
}
|
||||
Err(err) => {
|
||||
let error_msg = err.to_string();
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Owned(processor_name.to_string()),
|
||||
message: Cow::Owned(error_msg),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "otel")]
|
||||
{
|
||||
let stage_ms = stage_start.elapsed().as_secs_f64() * 1000.0;
|
||||
crate::telemetry::metrics::get_metrics().pipeline_duration_ms.record(
|
||||
stage_ms,
|
||||
&[opentelemetry::KeyValue::new(
|
||||
crate::telemetry::conventions::PIPELINE_STAGE,
|
||||
stage_name.to_string(),
|
||||
)],
|
||||
);
|
||||
drop(_stage_guard);
|
||||
drop(stage_span);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Determine if a processor should run based on configuration.
|
||||
fn should_processor_run(pp_config: &Option<&crate::core::config::PostProcessorConfig>, processor_name: &str) -> bool {
|
||||
if let Some(config) = pp_config {
|
||||
if let Some(ref enabled_set) = config.enabled_set {
|
||||
enabled_set.contains(processor_name)
|
||||
} else if let Some(ref disabled_set) = config.disabled_set {
|
||||
!disabled_set.contains(processor_name)
|
||||
} else if let Some(ref enabled) = config.enabled_processors {
|
||||
enabled.iter().any(|name| name == processor_name)
|
||||
} else if let Some(ref disabled) = config.disabled_processors {
|
||||
!disabled.iter().any(|name| name == processor_name)
|
||||
} else {
|
||||
true
|
||||
}
|
||||
} else {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
/// Execute all registered validators.
|
||||
pub(super) async fn execute_validators(result: &ExtractionResult, config: &ExtractionConfig) -> Result<()> {
|
||||
let validator_registry = crate::plugins::registry::get_validator_registry();
|
||||
let validators = {
|
||||
let registry = validator_registry.read();
|
||||
registry.get_all()
|
||||
};
|
||||
|
||||
if !validators.is_empty() {
|
||||
for validator in validators {
|
||||
if validator.should_validate(result, config) {
|
||||
validator.validate(result, config).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
461
crates/kreuzberg/src/core/pipeline/features.rs
Normal file
461
crates/kreuzberg/src/core/pipeline/features.rs
Normal file
@@ -0,0 +1,461 @@
|
||||
//! Feature processing logic.
|
||||
//!
|
||||
//! This module handles feature-specific processing like chunking,
|
||||
//! embedding generation, and language detection.
|
||||
|
||||
use crate::Result;
|
||||
use crate::core::config::ExtractionConfig;
|
||||
#[cfg(feature = "chunking")]
|
||||
use crate::types::PageBoundary;
|
||||
use crate::types::{ExtractionResult, ProcessingWarning};
|
||||
use std::borrow::Cow;
|
||||
|
||||
/// Re-derive page boundaries as byte ranges into the rendered `content`.
///
/// Boundaries produced during extraction refer to the raw page text, whereas
/// `result.content` comes from `render_plain`, which trims every paragraph. Raw
/// pages with trailing-space artifacts therefore have different byte lengths
/// than their rendered counterparts. To obtain offsets that are valid for
/// `content`, each page is located inside it, scanning forward from the end of
/// the previously located page:
///
/// 1. Blank pages get an empty range at the current scan position.
/// 2. Otherwise the page text is paragraph-normalised (split on `"\n\n"`, each
///    segment trimmed, empty segments dropped, re-joined) and searched verbatim.
/// 3. If that fails, the first non-blank line of the page is searched instead,
///    and the range is given the normalised length, capped at the end of
///    `content`.
///
/// Pages that cannot be located at all are skipped (with a debug log), so the
/// returned vector may be shorter than `pages`.
#[cfg(feature = "chunking")]
fn recompute_boundaries_from_pages(content: &str, pages: &[crate::types::PageContent]) -> Vec<PageBoundary> {
    let mut boundaries = Vec::with_capacity(pages.len());
    // Byte offset in `content` from which the next page is searched.
    let mut cursor = 0usize;

    for page in pages {
        if page.content.trim().is_empty() {
            boundaries.push(PageBoundary {
                page_number: page.page_number,
                byte_start: cursor,
                byte_end: cursor,
            });
            continue;
        }

        // Mirror what render_plain does to paragraphs so an exact match is possible.
        let normalized = page
            .content
            .split("\n\n")
            .map(str::trim)
            .filter(|segment| !segment.is_empty())
            .collect::<Vec<_>>()
            .join("\n\n");

        let remaining = &content[cursor..];

        let located: Option<(usize, usize)> = match remaining.find(normalized.as_str()) {
            // Primary path: the normalised page text occurs verbatim.
            Some(rel) => {
                let start = cursor + rel;
                Some((start, content.floor_char_boundary(start + normalized.len())))
            }
            // Fallback: anchor on the first non-blank line and assume the
            // normalised length so the cursor still advances sensibly.
            None => page
                .content
                .lines()
                .map(str::trim)
                .find(|line| !line.is_empty())
                .and_then(|first_line| remaining.find(first_line))
                .map(|rel| {
                    let start = cursor + rel;
                    let capped_end = (start + normalized.len()).min(content.len());
                    (start, content.floor_char_boundary(capped_end))
                }),
        };

        match located {
            Some((byte_start, byte_end)) => {
                boundaries.push(PageBoundary {
                    page_number: page.page_number,
                    byte_start,
                    byte_end,
                });
                cursor = byte_end;
            }
            None => {
                tracing::debug!(
                    page = page.page_number,
                    "Could not locate page content in rendered text — skipping page boundary"
                );
            }
        }
    }

    boundaries
}
|
||||
|
||||
/// Map TSLP `CodeChunk`s directly to kreuzberg `Chunk`s, bypassing text-splitter.
///
/// When the extraction result carries code metadata with a non-empty chunk
/// list, those chunks already follow semantically meaningful code boundaries
/// produced by tree-sitter, and re-splitting them with text-splitter would
/// break those boundaries.
///
/// Returns `None` when the result has no code metadata or its chunk list is
/// empty. Otherwise returns one `Chunk` per code chunk, in order, where:
///
/// * every chunk is typed `ChunkType::CodeBlock`;
/// * byte offsets are copied from the code chunk;
/// * `context_path` entries become heading levels starting at 2 and capped at 6;
/// * embeddings, token counts, page ranges and image indices are left unset.
#[cfg(feature = "tree-sitter")]
fn try_code_chunks(result: &ExtractionResult) -> Option<Vec<crate::types::extraction::Chunk>> {
    use crate::types::extraction::{Chunk, ChunkMetadata, ChunkType, HeadingContext, HeadingLevel};

    let code_chunks = match &result.metadata.format {
        Some(crate::types::metadata::FormatMetadata::Code(pr)) if !pr.chunks.is_empty() => &pr.chunks,
        _ => return None,
    };

    let total_chunks = code_chunks.len();
    let chunks: Vec<Chunk> = code_chunks
        .iter()
        .enumerate()
        .map(|(i, cc)| {
            // Build heading context from context_path; an empty path means no context.
            let heading_context = if cc.metadata.context_path.is_empty() {
                None
            } else {
                Some(HeadingContext {
                    headings: cc
                        .metadata
                        .context_path
                        .iter()
                        .enumerate()
                        .map(|(depth, name)| HeadingLevel {
                            // Saturating conversion: a plain `depth as u8` would wrap
                            // for depth >= 256 and yield a level below the cap.
                            level: u8::try_from(depth).unwrap_or(u8::MAX).saturating_add(2).min(6),
                            text: name.clone(),
                        })
                        .collect(),
                })
            };

            Chunk {
                content: cc.content.clone(),
                // All code chunks are classified as CodeBlock regardless of node type.
                chunk_type: ChunkType::CodeBlock,
                embedding: None,
                metadata: ChunkMetadata {
                    byte_start: cc.start_byte,
                    byte_end: cc.end_byte,
                    token_count: None,
                    chunk_index: i,
                    total_chunks,
                    first_page: None,
                    last_page: None,
                    heading_context,
                    image_indices: Vec::new(),
                },
            }
        })
        .collect();

    Some(chunks)
}
|
||||
|
||||
/// Execute chunking if configured.
|
||||
pub(super) fn execute_chunking(result: &mut ExtractionResult, config: &ExtractionConfig) -> Result<()> {
|
||||
#[cfg(feature = "chunking")]
|
||||
if let Some(ref chunking_config) = config.chunking {
|
||||
// For code extractions with TSLP chunks, bypass text-splitter and map directly.
|
||||
#[cfg(feature = "tree-sitter")]
|
||||
if let Some(code_chunks) = try_code_chunks(result) {
|
||||
result.chunks = Some(code_chunks);
|
||||
|
||||
let resolved_config = chunking_config.resolve_preset();
|
||||
#[cfg(feature = "embeddings")]
|
||||
if let Some(ref embedding_config) = resolved_config.embedding
|
||||
&& let Some(ref mut chunks) = result.chunks
|
||||
&& let Err(e) = crate::embeddings::generate_embeddings_for_chunks(chunks, embedding_config)
|
||||
{
|
||||
tracing::warn!("Embedding generation failed: {e}. Check that ONNX Runtime is installed.");
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("embedding"),
|
||||
message: Cow::Owned(e.to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "embeddings"))]
|
||||
if resolved_config.embedding.is_some() {
|
||||
tracing::warn!(
|
||||
"Embedding config provided but embeddings feature is not enabled. Recompile with --features embeddings."
|
||||
);
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("embedding"),
|
||||
message: Cow::Borrowed("Embeddings feature not enabled"),
|
||||
});
|
||||
}
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let resolved_config = chunking_config.resolve_preset();
|
||||
let chunking_config = &resolved_config;
|
||||
|
||||
// Recompute page boundaries against `result.content` (rendered by `render_plain`)
|
||||
// if per-page content is available. The boundaries stored in
|
||||
// `result.metadata.pages.boundaries` were computed against the raw extractor text
|
||||
// and may have different byte offsets than the rendered content.
|
||||
let recomputed_boundaries: Option<Vec<PageBoundary>> = result
|
||||
.pages
|
||||
.as_deref()
|
||||
.map(|pages| recompute_boundaries_from_pages(&result.content, pages));
|
||||
|
||||
let page_boundaries: Option<&[PageBoundary]> = recomputed_boundaries
|
||||
.as_deref()
|
||||
.filter(|s| !s.is_empty())
|
||||
.or_else(|| result.metadata.pages.as_ref().and_then(|ps| ps.boundaries.as_deref()));
|
||||
|
||||
// Pass formatted_content (markdown) for heading context resolution when available.
|
||||
// Plain-text rendering strips heading markers, but the markdown chunker needs them
|
||||
// to build the heading hierarchy for chunk metadata.
|
||||
let heading_source = result.formatted_content.as_deref();
|
||||
match crate::chunking::chunk_text_with_heading_source(
|
||||
&result.content,
|
||||
chunking_config,
|
||||
page_boundaries,
|
||||
heading_source,
|
||||
) {
|
||||
Ok(chunking_result) => {
|
||||
result.chunks = Some(chunking_result.chunks);
|
||||
|
||||
// Populate image_indices on each chunk: collect indices of images whose
|
||||
// page_number falls within the chunk's [first_page, last_page] range.
|
||||
if let Some(ref images) = result.images
|
||||
&& let Some(ref mut chunks) = result.chunks
|
||||
{
|
||||
for chunk in chunks.iter_mut() {
|
||||
if let (Some(first), Some(last)) = (chunk.metadata.first_page, chunk.metadata.last_page) {
|
||||
chunk.metadata.image_indices = images
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, img)| {
|
||||
let pg = img.page_number?;
|
||||
(pg >= first && pg <= last).then_some(idx as u32)
|
||||
})
|
||||
.collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "embeddings")]
|
||||
if let Some(ref embedding_config) = chunking_config.embedding
|
||||
&& let Some(ref mut chunks) = result.chunks
|
||||
&& let Err(e) = crate::embeddings::generate_embeddings_for_chunks(chunks, embedding_config)
|
||||
{
|
||||
tracing::warn!("Embedding generation failed: {e}. Check that ONNX Runtime is installed.");
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("embedding"),
|
||||
message: Cow::Owned(e.to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "embeddings"))]
|
||||
if chunking_config.embedding.is_some() {
|
||||
tracing::warn!(
|
||||
"Embedding config provided but embeddings feature is not enabled. Recompile with --features embeddings."
|
||||
);
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("embedding"),
|
||||
message: Cow::Borrowed("Embeddings feature not enabled"),
|
||||
});
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("chunking"),
|
||||
message: Cow::Owned(e.to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "chunking"))]
|
||||
if config.chunking.is_some() {
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("chunking"),
|
||||
message: Cow::Borrowed("Chunking feature not enabled"),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute language detection if configured.
|
||||
pub(super) fn execute_language_detection(result: &mut ExtractionResult, config: &ExtractionConfig) -> Result<()> {
|
||||
#[cfg(feature = "language-detection")]
|
||||
if let Some(ref lang_config) = config.language_detection {
|
||||
match crate::language_detection::detect_languages(&result.content, lang_config) {
|
||||
Ok(detected) => {
|
||||
result.detected_languages = detected;
|
||||
}
|
||||
Err(e) => {
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("language_detection"),
|
||||
message: Cow::Owned(e.to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "language-detection"))]
|
||||
if config.language_detection.is_some() {
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("language_detection"),
|
||||
message: Cow::Borrowed("Language detection feature not enabled"),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Execute token reduction if configured.
|
||||
pub(super) fn execute_token_reduction(result: &mut ExtractionResult, config: &ExtractionConfig) -> Result<()> {
|
||||
#[cfg(feature = "quality")]
|
||||
if let Some(ref tr_config) = config.token_reduction {
|
||||
let level = crate::text::token_reduction::ReductionLevel::from(tr_config.mode.as_str());
|
||||
|
||||
if !matches!(level, crate::text::token_reduction::ReductionLevel::Off) {
|
||||
let impl_config = crate::text::token_reduction::TokenReductionConfig {
|
||||
level,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let lang_hint: Option<&str> = result
|
||||
.detected_languages
|
||||
.as_deref()
|
||||
.and_then(|langs| langs.first().map(|s| s.as_str()));
|
||||
|
||||
match crate::text::token_reduction::reduce_tokens(&result.content, &impl_config, lang_hint) {
|
||||
Ok(reduced) => {
|
||||
result.content = reduced;
|
||||
}
|
||||
Err(e) => {
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("token_reduction"),
|
||||
message: Cow::Owned(e.to_string()),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(not(feature = "quality"))]
|
||||
if config.token_reduction.is_some() {
|
||||
result.processing_warnings.push(ProcessingWarning {
|
||||
source: Cow::Borrowed("token_reduction"),
|
||||
message: Cow::Borrowed("Token reduction requires the quality feature"),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(all(test, feature = "chunking"))]
mod tests {
    use super::*;
    use crate::types::PageContent;

    /// Build a page that carries only a number and its text.
    fn make_page(page_number: u32, content: impl Into<String>) -> PageContent {
        PageContent {
            page_number,
            content: content.into(),
            tables: vec![],
            image_indices: vec![],
            hierarchy: None,
            is_blank: None,
            layout_regions: None,
            section_name: None,
            speaker_notes: None,
            sheet_name: None,
        }
    }

    /// Build pages numbered from 1 out of the given texts.
    fn pages_from(texts: &[&str]) -> Vec<PageContent> {
        (1u32..).zip(texts).map(|(number, text)| make_page(number, *text)).collect()
    }

    /// The slice of `content` covered by `boundary`.
    fn text_at<'a>(content: &'a str, boundary: &PageBoundary) -> &'a str {
        &content[boundary.byte_start..boundary.byte_end]
    }

    // Page texts identical to their occurrence in the content: every page resolves.
    #[test]
    fn recompute_boundaries_exact_match_produces_full_boundary_set() {
        let texts = ["Hello world", "Second page text", "Third page here"];
        let content = texts.join("\n\n");

        let boundaries = recompute_boundaries_from_pages(&content, &pages_from(&texts));

        assert_eq!(boundaries.len(), 3, "all pages should resolve to boundaries");
        for (boundary, expected) in boundaries.iter().zip(texts) {
            assert_eq!(text_at(&content, boundary), expected);
        }
    }

    // A page still holding raw text (control character present) while the content
    // holds the cleaned text cannot be located and is skipped, leaving fewer
    // boundaries than pages. Documents the pre-fix failure mode.
    #[test]
    fn recompute_boundaries_raw_content_causes_skipped_pages() {
        // In the content the U+0001 between word characters has been replaced by '-'.
        let content = ["Hello world", "ab-cd", "Third page"].join("\n\n");

        // Page 2 intentionally keeps the stale raw text.
        let pages = pages_from(&["Hello world", "ab\x01cd", "Third page"]);
        let boundaries = recompute_boundaries_from_pages(&content, &pages);

        // Neither the exact nor the first-line search finds "ab\x01cd" in the
        // content, so only pages 1 and 3 resolve.
        assert_eq!(boundaries.len(), 2, "page with raw/cleaned mismatch should be skipped");
        let resolved: Vec<_> = boundaries.iter().map(|b| b.page_number).collect();
        assert_eq!(resolved, [1, 3]);
    }

    // With the cleaned text on the page as well (the fix), every page resolves.
    #[test]
    fn recompute_boundaries_cleaned_content_resolves_all_pages() {
        let texts = ["Hello world", "ab-cd", "Third page"];
        let content = texts.join("\n\n");

        let boundaries = recompute_boundaries_from_pages(&content, &pages_from(&texts));

        assert_eq!(boundaries.len(), 3, "all pages should resolve after fix");
        assert_eq!(text_at(&content, &boundaries[1]), "ab-cd");
    }

    // PDF pages often carry trailing spaces before the "\n\n" paragraph separator
    // (a rendering artifact). render_plain trims each paragraph, so the content
    // lacks those spaces while the page text keeps them. The normalised match must
    // still succeed and yield a correct byte_end, otherwise the search offset
    // would over-advance and later pages would be missed.
    #[test]
    fn recompute_boundaries_trailing_space_pages_all_resolve() {
        // Page texts with trailing spaces before "\n\n".
        let raw_pages = [
            "Heading   \n\nBody paragraph one.  ",
            "Second heading  \n\nBody paragraph two.   ",
            "Conclusion.  ",
        ];
        // The same pages as render_plain produces them (each paragraph trimmed).
        let rendered_pages = [
            "Heading\n\nBody paragraph one.",
            "Second heading\n\nBody paragraph two.",
            "Conclusion.",
        ];
        let content = rendered_pages.join("\n\n");

        let boundaries = recompute_boundaries_from_pages(&content, &pages_from(&raw_pages));

        assert_eq!(boundaries.len(), 3, "all pages must resolve despite trailing spaces");
        for (boundary, expected) in boundaries.iter().zip(rendered_pages) {
            assert_eq!(text_at(&content, boundary), expected);
        }
    }
}
|
||||
207
crates/kreuzberg/src/core/pipeline/format.rs
Normal file
207
crates/kreuzberg/src/core/pipeline/format.rs
Normal file
@@ -0,0 +1,207 @@
|
||||
//! Output format conversion for extraction results.
|
||||
//!
|
||||
//! This module handles the final step of output format application: swapping
|
||||
//! pre-rendered content into the result and recording format metadata.
|
||||
//!
|
||||
//! The heavy rendering work (Markdown, Djot, HTML) is now done earlier in the
|
||||
//! pipeline inside `derive_extraction_result`, which populates
|
||||
//! `ExtractionResult::formatted_content`. This function simply swaps that
|
||||
//! pre-rendered content into the `content` field after post-processors have
|
||||
//! operated on the plain-text version.
|
||||
|
||||
use crate::core::config::OutputFormat;
|
||||
use crate::types::ExtractionResult;
|
||||
#[cfg(test)]
|
||||
use std::borrow::Cow;
|
||||
|
||||
/// Apply output format conversion to the extraction result.
|
||||
///
|
||||
/// Records the output format in metadata and swaps in pre-rendered content
|
||||
/// (produced during `derive_extraction_result`) if available.
|
||||
///
|
||||
/// This runs as the final pipeline step, after post-processors have operated
|
||||
/// on the plain-text `content` field.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `result` - The extraction result to modify
|
||||
/// * `output_format` - The desired output format
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn apply_output_format(result: ExtractionResult, output_format: OutputFormat) -> ExtractionResult {
|
||||
let mut result = result;
|
||||
let format_name = match output_format {
|
||||
OutputFormat::Plain => "plain",
|
||||
OutputFormat::Markdown => "markdown",
|
||||
OutputFormat::Djot => "djot",
|
||||
OutputFormat::Html => "html",
|
||||
OutputFormat::Json => "json",
|
||||
OutputFormat::Structured => "structured",
|
||||
OutputFormat::Custom(ref name) => name.as_str(),
|
||||
};
|
||||
result.metadata.output_format = Some(format_name.to_string());
|
||||
|
||||
// Swap in pre-rendered content if available (populated by derive_extraction_result).
|
||||
if let Some(formatted) = result.formatted_content.take() {
|
||||
result.content = formatted;
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::Metadata;

    /// Build a `text/plain` result with the given plain `content` and optional
    /// pre-rendered `formatted` content.
    fn plain_text_result(content: &str, formatted: Option<&str>) -> ExtractionResult {
        ExtractionResult {
            content: content.to_string(),
            mime_type: Cow::Borrowed("text/plain"),
            formatted_content: formatted.map(|text| text.to_string()),
            ..Default::default()
        }
    }

    #[test]
    fn test_apply_output_format_plain() {
        let formatted = apply_output_format(plain_text_result("Hello World", None), OutputFormat::Plain);

        assert_eq!(formatted.content, "Hello World");
        assert_eq!(formatted.metadata.output_format.as_deref(), Some("plain"));
    }

    #[test]
    fn test_apply_output_format_markdown_no_prerender() {
        let formatted = apply_output_format(plain_text_result("Hello World", None), OutputFormat::Markdown);

        // Nothing was pre-rendered, so the plain content must survive unchanged.
        assert_eq!(formatted.content, "Hello World");
        assert_eq!(formatted.metadata.output_format.as_deref(), Some("markdown"));
    }

    #[test]
    fn test_apply_output_format_swaps_formatted_content() {
        let input = plain_text_result("plain text", Some("# Heading\n\nFormatted markdown"));

        let formatted = apply_output_format(input, OutputFormat::Markdown);

        assert_eq!(formatted.content, "# Heading\n\nFormatted markdown");
        assert!(formatted.formatted_content.is_none(), "formatted_content should be taken");
        assert_eq!(formatted.metadata.output_format.as_deref(), Some("markdown"));
    }

    #[test]
    fn test_apply_output_format_html_with_prerender() {
        let input = plain_text_result("plain text", Some("<p>Hello World</p>"));

        let formatted = apply_output_format(input, OutputFormat::Html);

        assert_eq!(formatted.content, "<p>Hello World</p>");
        assert_eq!(formatted.metadata.output_format.as_deref(), Some("html"));
    }

    #[test]
    fn test_apply_output_format_djot_with_prerender() {
        let input = plain_text_result("plain text", Some("# Djot heading"));

        let formatted = apply_output_format(input, OutputFormat::Djot);

        assert_eq!(formatted.content, "# Djot heading");
        assert_eq!(formatted.metadata.output_format.as_deref(), Some("djot"));
    }

    #[test]
    fn test_apply_output_format_structured() {
        let formatted = apply_output_format(plain_text_result("Hello World", None), OutputFormat::Structured);

        assert_eq!(formatted.content, "Hello World");
        assert_eq!(formatted.metadata.output_format.as_deref(), Some("structured"));
    }

    #[test]
    fn test_apply_output_format_preserves_metadata() {
        use ahash::AHashMap;

        let mut extra_fields = AHashMap::new();
        extra_fields.insert(Cow::Borrowed("custom_key"), serde_json::json!("custom_value"));

        let input = ExtractionResult {
            metadata: Metadata {
                title: Some("Test Title".to_string()),
                additional: extra_fields,
                ..Default::default()
            },
            ..plain_text_result("Hello World", None)
        };

        let formatted = apply_output_format(input, OutputFormat::Markdown);

        // Pre-existing metadata must not be disturbed by recording the format.
        assert_eq!(formatted.metadata.title.as_deref(), Some("Test Title"));
        assert_eq!(
            formatted.metadata.additional.get("custom_key"),
            Some(&serde_json::json!("custom_value"))
        );
    }

    #[test]
    fn test_apply_output_format_preserves_tables() {
        use crate::types::Table;

        let input = ExtractionResult {
            tables: vec![Table {
                cells: vec![vec!["A".to_string(), "B".to_string()]],
                markdown: "| A | B |".to_string(),
                page_number: 1,
                bounding_box: None,
            }],
            ..plain_text_result("Hello World", None)
        };

        let formatted = apply_output_format(input, OutputFormat::Html);

        assert_eq!(formatted.tables.len(), 1);
        assert_eq!(formatted.tables[0].cells[0][0], "A");
    }

    #[test]
    fn test_apply_output_format_sets_typed_field() {
        let formatted = apply_output_format(plain_text_result("test", None), OutputFormat::Djot);

        assert_eq!(formatted.metadata.output_format.as_deref(), Some("djot"));
    }
}
|
||||
67
crates/kreuzberg/src/core/pipeline/initialization.rs
Normal file
67
crates/kreuzberg/src/core/pipeline/initialization.rs
Normal file
@@ -0,0 +1,67 @@
|
||||
//! Pipeline initialization and setup logic.
|
||||
//!
|
||||
//! This module handles the initialization of features and processor cache
|
||||
//! required for pipeline execution.
|
||||
|
||||
use crate::Result;
|
||||
#[cfg(feature = "quality")]
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use super::cache::{PROCESSOR_CACHE, ProcessorCache};
|
||||
|
||||
/// Type alias for processor stages tuple (Early, Middle, Late).
|
||||
type ProcessorStages = (
|
||||
std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
|
||||
std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
|
||||
std::sync::Arc<Vec<std::sync::Arc<dyn crate::plugins::PostProcessor>>>,
|
||||
);
|
||||
|
||||
/// Initialize feature-specific systems that may be needed during pipeline execution.
|
||||
pub(super) fn initialize_features() {
|
||||
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
|
||||
{
|
||||
let _ = crate::keywords::ensure_initialized();
|
||||
}
|
||||
|
||||
#[cfg(feature = "language-detection")]
|
||||
{
|
||||
let _ = crate::language_detection::ensure_initialized();
|
||||
}
|
||||
|
||||
#[cfg(feature = "chunking")]
|
||||
{
|
||||
let _ = crate::chunking::ensure_initialized();
|
||||
}
|
||||
|
||||
#[cfg(feature = "quality")]
|
||||
{
|
||||
static QUALITY_INIT: OnceLock<()> = OnceLock::new();
|
||||
QUALITY_INIT.get_or_init(|| {
|
||||
let registry = crate::plugins::registry::get_post_processor_registry();
|
||||
let mut reg = registry.write();
|
||||
let _ = reg.register(std::sync::Arc::new(crate::text::QualityProcessor));
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the processor cache if not already initialized.
|
||||
pub(super) fn initialize_processor_cache() -> Result<()> {
|
||||
let mut cache_lock = PROCESSOR_CACHE.write();
|
||||
if cache_lock.is_none() {
|
||||
*cache_lock = Some(ProcessorCache::new()?);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get processors from the cache, organized by stage.
|
||||
pub(super) fn get_processors_from_cache() -> Result<ProcessorStages> {
|
||||
let cache_lock = PROCESSOR_CACHE.read();
|
||||
let cache = cache_lock
|
||||
.as_ref()
|
||||
.ok_or_else(|| crate::KreuzbergError::Other("Processor cache not initialized".to_string()))?;
|
||||
Ok((
|
||||
std::sync::Arc::clone(&cache.early),
|
||||
std::sync::Arc::clone(&cache.middle),
|
||||
std::sync::Arc::clone(&cache.late),
|
||||
))
|
||||
}
|
||||
478
crates/kreuzberg/src/core/pipeline/mod.rs
Normal file
478
crates/kreuzberg/src/core/pipeline/mod.rs
Normal file
@@ -0,0 +1,478 @@
|
||||
//! Post-processing pipeline orchestration.
|
||||
//!
|
||||
//! This module orchestrates the post-processing pipeline, executing validators,
|
||||
//! quality processing, chunking, and custom hooks in the correct order.
|
||||
|
||||
mod cache;
|
||||
mod execution;
|
||||
mod features;
|
||||
mod format;
|
||||
mod initialization;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
pub use cache::clear_processor_cache;
|
||||
pub use format::apply_output_format;
|
||||
|
||||
use crate::Result;
|
||||
use crate::core::config::ExtractionConfig;
|
||||
use crate::types::ExtractionResult;
|
||||
use crate::types::internal::InternalDocument;
|
||||
|
||||
use execution::{execute_processors, execute_validators};
|
||||
use features::{execute_chunking, execute_language_detection, execute_token_reduction};
|
||||
use initialization::{get_processors_from_cache, initialize_features, initialize_processor_cache};
|
||||
|
||||
/// Run the post-processing pipeline on an `InternalDocument`.
|
||||
///
|
||||
/// Derives `ExtractionResult` from `InternalDocument` via the derivation pipeline,
|
||||
/// then executes post-processing in the following order:
|
||||
/// 1. Post-Processors - Execute by stage (Early, Middle, Late) to modify/enhance the result
|
||||
/// 2. Quality Processing - Text cleaning and quality scoring
|
||||
/// 3. Chunking - Text splitting if enabled
|
||||
/// 4. Validators - Run validation hooks on the processed result (can fail fast)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `doc` - The internal document produced by the extractor
|
||||
/// * `config` - Extraction configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The processed extraction result.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// - Validator errors bubble up immediately
|
||||
/// - Post-processor errors are caught and recorded in metadata
|
||||
/// - System errors (IO, RuntimeError equivalents) always bubble up
|
||||
#[cfg_attr(feature = "otel", tracing::instrument(
|
||||
skip(doc, config),
|
||||
fields(
|
||||
pipeline.stage = "post_processing",
|
||||
content.element_count = doc.elements.len(),
|
||||
)
|
||||
))]
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub async fn run_pipeline(mut doc: InternalDocument, config: &ExtractionConfig) -> Result<ExtractionResult> {
|
||||
// Propagate rendering preferences from config into the document.
|
||||
doc.ocr_text_only = config.images.as_ref().map(|i| i.ocr_text_only).unwrap_or(false);
|
||||
doc.append_ocr_text = config.images.as_ref().map(|i| i.append_ocr_text).unwrap_or(false);
|
||||
|
||||
// 1. Process extracted images with OCR if configured
|
||||
#[cfg(all(feature = "ocr", feature = "tokio-runtime"))]
|
||||
let image_ocr_enabled = config.images.as_ref().map(|i| i.run_ocr_on_images).unwrap_or(true);
|
||||
#[cfg(all(feature = "ocr", feature = "tokio-runtime"))]
|
||||
if image_ocr_enabled && config.ocr.is_some() && !doc.images.is_empty() {
|
||||
let images_to_process = std::mem::take(&mut doc.images);
|
||||
match crate::extraction::image_ocr::process_images_with_ocr(
|
||||
images_to_process,
|
||||
config,
|
||||
&mut doc.processing_warnings,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(processed) => {
|
||||
doc.images = processed;
|
||||
}
|
||||
Err(e) => {
|
||||
doc.processing_warnings.push(crate::types::ProcessingWarning {
|
||||
source: std::borrow::Cow::Borrowed("image_ocr"),
|
||||
message: std::borrow::Cow::Owned(format!("Image OCR failed: {e}")),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
replace_embedded_image_markdown_with_ocr(&mut doc);
|
||||
append_embedded_image_ocr_text(&mut doc);
|
||||
|
||||
// Pre-render markdown for the chunker's heading context resolution when:
|
||||
// - Markdown chunking is configured
|
||||
// - Output format is not already Markdown (which would produce formatted_content anyway)
|
||||
// Plain-text rendering strips heading markers, so the markdown chunker needs
|
||||
// a separate markdown rendering to build the heading hierarchy for chunk metadata.
|
||||
#[cfg(feature = "chunking")]
|
||||
let chunker_heading_source = {
|
||||
let needs_markdown = config.chunking.as_ref().is_some_and(|c| {
|
||||
c.chunker_type == crate::core::config::ChunkerType::Markdown
|
||||
|| c.resolve_preset().chunker_type == crate::core::config::ChunkerType::Markdown
|
||||
}) && config.output_format == crate::core::config::OutputFormat::Plain;
|
||||
if needs_markdown {
|
||||
Some(crate::rendering::render_markdown(&doc))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Pre-render styled HTML before `doc` is consumed by `derive_extraction_result`.
|
||||
// When `html` is active and the caller has configured `html_output`, we
|
||||
// render the document here and inject the result after derivation.
|
||||
#[cfg(feature = "html")]
|
||||
let styled_html_prerender: Option<String> = {
|
||||
use crate::plugins::Renderer as _;
|
||||
if config.output_format == crate::core::config::OutputFormat::Html {
|
||||
config.html_output.as_ref().and_then(|html_cfg| {
|
||||
match crate::rendering::StyledHtmlRenderer::new(html_cfg.clone()) {
|
||||
Ok(renderer) => match renderer.render(&doc) {
|
||||
Ok(html) => Some(html),
|
||||
Err(e) => {
|
||||
tracing::warn!("StyledHtmlRenderer render failed, falling back to default HTML: {e}");
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!("StyledHtmlRenderer construction failed, falling back to default HTML: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// 2. Derive ExtractionResult from InternalDocument
|
||||
let include_structure = config.include_document_structure;
|
||||
let mut result =
|
||||
crate::extraction::derive::derive_extraction_result(doc, include_structure, config.output_format.clone());
|
||||
|
||||
// Inject pre-rendered styled HTML (overrides the default render_html output).
|
||||
#[cfg(feature = "html")]
|
||||
if let Some(html) = styled_html_prerender {
|
||||
result.formatted_content = Some(html);
|
||||
}
|
||||
|
||||
// Temporarily store pre-rendered markdown for chunker heading context.
|
||||
// Tracked separately so we can remove it after chunking — apply_output_format
|
||||
// must not swap this into result.content when output_format is Plain.
|
||||
#[cfg(feature = "chunking")]
|
||||
let chunker_only_markdown = result.formatted_content.is_none();
|
||||
#[cfg(feature = "chunking")]
|
||||
if chunker_only_markdown && let Some(md) = chunker_heading_source {
|
||||
result.formatted_content = Some(md);
|
||||
}
|
||||
|
||||
// 2. Run post-processing pipeline
|
||||
let pp_config = config.postprocessor.as_ref();
|
||||
let postprocessing_enabled = pp_config.is_none_or(|c| c.enabled);
|
||||
|
||||
if postprocessing_enabled {
|
||||
initialize_features();
|
||||
initialize_processor_cache()?;
|
||||
|
||||
let (early_processors, middle_processors, late_processors) = get_processors_from_cache()?;
|
||||
|
||||
execute_processors(
|
||||
&mut result,
|
||||
config,
|
||||
&pp_config,
|
||||
early_processors,
|
||||
middle_processors,
|
||||
late_processors,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
execute_chunking(&mut result, config)?;
|
||||
|
||||
// Clear temporary markdown if it was only stored for chunker heading context.
|
||||
// This prevents apply_output_format from swapping it into result.content.
|
||||
#[cfg(feature = "chunking")]
|
||||
if chunker_only_markdown {
|
||||
result.formatted_content = None;
|
||||
}
|
||||
|
||||
execute_language_detection(&mut result, config)?;
|
||||
execute_token_reduction(&mut result, config)?;
|
||||
execute_validators(&result, config).await?;
|
||||
|
||||
apply_element_transform(&mut result, config);
|
||||
normalize_nfc(&mut result);
|
||||
|
||||
// Run LLM-based structured extraction BEFORE output formatting
|
||||
// so extraction sees plain text, not markdown/HTML
|
||||
// TODO(wasm-llm): hosted structured extraction should run on wasm through
|
||||
// liter-llm's wasm-http backend once browser/runtime support is wired.
|
||||
#[cfg(all(feature = "liter-llm", not(target_os = "windows"), not(target_arch = "wasm32")))]
|
||||
if let Some(ref structured_config) = config.structured_extraction {
|
||||
match crate::llm::structured::extract_structured(&result.content, structured_config).await {
|
||||
Ok((output, usage)) => {
|
||||
result.structured_output = Some(output);
|
||||
crate::llm::usage::push_llm_usage(&mut result, usage);
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Structured extraction failed: {e}");
|
||||
result.processing_warnings.push(crate::types::ProcessingWarning {
|
||||
source: std::borrow::Cow::Borrowed("structured_extraction"),
|
||||
message: std::borrow::Cow::Owned(format!("Structured extraction failed: {e}")),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(wasm-llm): keep wasm in the fallback branch until structured
|
||||
// extraction has an async wasm-compatible runtime path.
|
||||
#[cfg(any(not(feature = "liter-llm"), target_os = "windows", target_arch = "wasm32"))]
|
||||
if config.structured_extraction.is_some() {
|
||||
result.processing_warnings.push(crate::types::ProcessingWarning {
|
||||
source: std::borrow::Cow::Borrowed("structured_extraction"),
|
||||
message: std::borrow::Cow::Borrowed("Structured extraction requires the 'liter-llm' feature"),
|
||||
});
|
||||
}
|
||||
|
||||
// Apply output format conversion as the final step
|
||||
result = apply_output_format(result, config.output_format.clone());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Run the post-processing pipeline synchronously (WASM-compatible version).
|
||||
///
|
||||
/// This is a synchronous implementation for WASM and non-async contexts.
|
||||
/// It performs a subset of the full async pipeline, excluding async post-processors
|
||||
/// and validators.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `doc` - The internal document produced by the extractor
|
||||
/// * `config` - Extraction configuration
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// The processed extraction result.
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// This function is only available when the `tokio-runtime` feature is disabled.
|
||||
/// It handles:
|
||||
/// - Quality processing (if enabled)
|
||||
/// - Chunking (if enabled)
|
||||
/// - Language detection (if enabled)
|
||||
///
|
||||
/// It does NOT handle:
|
||||
/// - Async post-processors
|
||||
/// - Async validators
|
||||
#[cfg(not(feature = "tokio-runtime"))]
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn run_pipeline_sync(doc: InternalDocument, config: &ExtractionConfig) -> Result<ExtractionResult> {
|
||||
// Pre-render markdown for chunker heading context (same logic as async path).
|
||||
#[cfg(feature = "chunking")]
|
||||
let chunker_heading_source = {
|
||||
let needs_markdown = config.chunking.as_ref().is_some_and(|c| {
|
||||
c.chunker_type == crate::core::config::ChunkerType::Markdown
|
||||
|| c.resolve_preset().chunker_type == crate::core::config::ChunkerType::Markdown
|
||||
}) && config.output_format == crate::core::config::OutputFormat::Plain;
|
||||
if needs_markdown {
|
||||
Some(crate::rendering::render_markdown(&doc))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// Pre-render styled HTML before `doc` is consumed (mirrors async path).
|
||||
#[cfg(feature = "html")]
|
||||
let styled_html_prerender: Option<String> = {
|
||||
use crate::plugins::Renderer as _;
|
||||
if config.output_format == crate::core::config::OutputFormat::Html {
|
||||
config.html_output.as_ref().and_then(|html_cfg| {
|
||||
match crate::rendering::StyledHtmlRenderer::new(html_cfg.clone()) {
|
||||
Ok(renderer) => match renderer.render(&doc) {
|
||||
Ok(html) => Some(html),
|
||||
Err(e) => {
|
||||
tracing::warn!("StyledHtmlRenderer render failed, falling back to default HTML: {e}");
|
||||
None
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!("StyledHtmlRenderer construction failed, falling back to default HTML: {e}");
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// 1. Derive ExtractionResult from InternalDocument
|
||||
let include_structure = config.include_document_structure;
|
||||
let mut result =
|
||||
crate::extraction::derive::derive_extraction_result(doc, include_structure, config.output_format.clone());
|
||||
|
||||
// Inject pre-rendered styled HTML.
|
||||
#[cfg(feature = "html")]
|
||||
if let Some(html) = styled_html_prerender {
|
||||
result.formatted_content = Some(html);
|
||||
}
|
||||
|
||||
#[cfg(feature = "chunking")]
|
||||
let chunker_only_markdown = result.formatted_content.is_none();
|
||||
#[cfg(feature = "chunking")]
|
||||
if chunker_only_markdown && let Some(md) = chunker_heading_source {
|
||||
result.formatted_content = Some(md);
|
||||
}
|
||||
|
||||
// 2. Run synchronous post-processing
|
||||
execute_chunking(&mut result, config)?;
|
||||
|
||||
#[cfg(feature = "chunking")]
|
||||
if chunker_only_markdown {
|
||||
result.formatted_content = None;
|
||||
}
|
||||
|
||||
execute_language_detection(&mut result, config)?;
|
||||
execute_token_reduction(&mut result, config)?;
|
||||
|
||||
apply_element_transform(&mut result, config);
|
||||
normalize_nfc(&mut result);
|
||||
|
||||
// Apply output format conversion as the final step
|
||||
result = apply_output_format(result, config.output_format.clone());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
/// Transform to element-based output if requested by the config.
|
||||
fn apply_element_transform(result: &mut ExtractionResult, config: &ExtractionConfig) {
|
||||
if config.result_format == crate::types::ResultFormat::ElementBased {
|
||||
result.elements = Some(crate::extraction::transform::transform_extraction_result_to_elements(
|
||||
result,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
/// Replace inline markdown image references with OCR text for formats (e.g. PPTX)
|
||||
/// that bake placeholders into paragraph text rather than using `ElementKind::Image`.
|
||||
fn replace_embedded_image_markdown_with_ocr(doc: &mut InternalDocument) {
|
||||
if !doc.ocr_text_only || doc.images.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut image_idx = 0usize;
|
||||
|
||||
for elem in &mut doc.elements {
|
||||
if !matches!(elem.kind, crate::types::internal::ElementKind::Paragraph) {
|
||||
continue;
|
||||
}
|
||||
if !is_markdown_image_reference(&elem.text) {
|
||||
continue;
|
||||
}
|
||||
if let Some(img) = doc.images.get(image_idx)
|
||||
&& let Some(ocr) = &img.ocr_result
|
||||
&& !ocr.content.is_empty()
|
||||
{
|
||||
elem.text = ocr.content.clone();
|
||||
image_idx += 1;
|
||||
continue;
|
||||
}
|
||||
image_idx += 1;
|
||||
}
|
||||
|
||||
for table in &mut doc.tables {
|
||||
for row in &mut table.cells {
|
||||
for cell in row {
|
||||
if !is_markdown_image_reference(cell) {
|
||||
continue;
|
||||
}
|
||||
if let Some(img) = doc.images.get(image_idx)
|
||||
&& let Some(ocr) = &img.ocr_result
|
||||
&& !ocr.content.is_empty()
|
||||
{
|
||||
*cell = ocr.content.clone();
|
||||
image_idx += 1;
|
||||
continue;
|
||||
}
|
||||
image_idx += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Append OCR text after inline markdown image references for formats (e.g. PPTX)
|
||||
/// that bake placeholders into paragraph text. Only runs when `append_ocr_text` is
|
||||
/// `true` and `ocr_text_only` is `false`.
|
||||
fn append_embedded_image_ocr_text(doc: &mut InternalDocument) {
|
||||
if doc.ocr_text_only || !doc.append_ocr_text || doc.images.is_empty() {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut image_idx = 0usize;
|
||||
let mut new_elements = Vec::with_capacity(doc.elements.len() * 2);
|
||||
|
||||
for elem in &doc.elements {
|
||||
new_elements.push(elem.clone());
|
||||
|
||||
if matches!(elem.kind, crate::types::internal::ElementKind::Paragraph)
|
||||
&& is_markdown_image_reference(&elem.text)
|
||||
{
|
||||
if let Some(img) = doc.images.get(image_idx)
|
||||
&& let Some(ocr) = &img.ocr_result
|
||||
&& !ocr.content.is_empty()
|
||||
{
|
||||
let ocr_elem = crate::types::internal::InternalElement::text(
|
||||
crate::types::internal::ElementKind::Paragraph,
|
||||
ocr.content.clone(),
|
||||
0,
|
||||
);
|
||||
new_elements.push(ocr_elem);
|
||||
}
|
||||
image_idx += 1;
|
||||
}
|
||||
}
|
||||
|
||||
doc.elements = new_elements;
|
||||
|
||||
for table in &mut doc.tables {
|
||||
for row in &mut table.cells {
|
||||
for cell in row {
|
||||
if !is_markdown_image_reference(cell) {
|
||||
continue;
|
||||
}
|
||||
if let Some(img) = doc.images.get(image_idx)
|
||||
&& let Some(ocr) = &img.ocr_result
|
||||
&& !ocr.content.is_empty()
|
||||
{
|
||||
*cell = format!("{}\n\n{}", cell.trim(), ocr.content);
|
||||
}
|
||||
image_idx += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns `true` if `text` is a markdown image reference (`![alt](url)`).
///
/// Surrounding whitespace is ignored. The check is purely structural: the
/// trimmed text must start with `![`, contain a `](` separator, and end
/// with `)`. The alt text and URL may be empty.
fn is_markdown_image_reference(text: &str) -> bool {
    let t = text.trim();
    if !t.starts_with("![") {
        return false;
    }
    // Locate the boundary between the alt text and the URL.
    let Some(bracket_end) = t.find("](") else {
        return false;
    };
    // The separator cannot overlap the leading `![`.
    if bracket_end < 2 {
        return false;
    }
    // Everything after `](` is the URL part and must close the reference.
    let after = &t[bracket_end + 2..];
    after.ends_with(')')
}
|
||||
|
||||
/// Apply NFC unicode normalization to all text content.
|
||||
///
|
||||
/// Ensures consistent representation of composed characters (e.g., é vs e+combining accent)
|
||||
/// across all extraction backends (PDF, OCR, DOCX, HTML, etc.).
|
||||
fn normalize_nfc(result: &mut ExtractionResult) {
|
||||
#[cfg(feature = "quality")]
|
||||
{
|
||||
use unicode_normalization::UnicodeNormalization;
|
||||
result.content = result.content.nfc().collect();
|
||||
if let Some(pages) = result.pages.as_mut() {
|
||||
for page in pages.iter_mut() {
|
||||
page.content = page.content.nfc().collect();
|
||||
}
|
||||
}
|
||||
}
|
||||
// Suppress unused variable warning when quality feature is disabled
|
||||
let _ = result;
|
||||
}
|
||||
993
crates/kreuzberg/src/core/pipeline/tests.rs
Normal file
993
crates/kreuzberg/src/core/pipeline/tests.rs
Normal file
@@ -0,0 +1,993 @@
|
||||
//! Pipeline orchestration tests.
|
||||
|
||||
use super::*;
|
||||
use crate::core::config::OutputFormat;
|
||||
use crate::types::Metadata;
|
||||
use crate::types::internal::{ElementKind, InternalDocument, InternalElement};
|
||||
use serial_test::serial;
|
||||
use std::borrow::Cow;
|
||||
|
||||
/// Build an `InternalDocument` with a single paragraph element for pipeline tests.
|
||||
fn make_doc(content: &str, mime: &str) -> InternalDocument {
|
||||
let mut doc = InternalDocument::new("plain");
|
||||
doc.mime_type = mime.to_string();
|
||||
if !content.is_empty() {
|
||||
doc.push_element(InternalElement::text(ElementKind::Paragraph, content, 0));
|
||||
}
|
||||
doc
|
||||
}
|
||||
|
||||
/// Build an `InternalDocument` with content, mime, and custom metadata.
|
||||
fn make_doc_with_metadata(content: &str, mime: &str, metadata: Metadata) -> InternalDocument {
|
||||
let mut doc = make_doc(content, mime);
|
||||
doc.metadata = metadata;
|
||||
doc
|
||||
}
|
||||
|
||||
const VALIDATION_MARKER_KEY: &str = "registry_validation_marker";
|
||||
#[cfg(feature = "quality")]
|
||||
const QUALITY_VALIDATION_MARKER: &str = "quality_validation_test";
|
||||
const POSTPROCESSOR_VALIDATION_MARKER: &str = "postprocessor_validation_test";
|
||||
const ORDER_VALIDATION_MARKER: &str = "order_validation_test";
|
||||
|
||||
/// Re-register the quality processor and reset the processor cache.
///
/// Needed because other tests may call `shutdown_all()` on the registry,
/// and the `OnceLock` in `initialize_features` prevents re-registration.
/// Both the registration result and the cache-clear result are ignored.
#[cfg(feature = "quality")]
fn ensure_quality_processor() {
    {
        let registry = crate::plugins::registry::get_post_processor_registry();
        let mut writer = registry.write();
        let _ = writer.register(std::sync::Arc::new(crate::text::QualityProcessor));
        // The write guard is released at the end of this scope, before the cache is cleared.
    }
    let _ = clear_processor_cache();
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_run_pipeline_basic() {
|
||||
let mut doc = make_doc("test", "text/plain");
|
||||
doc.metadata.additional.insert(
|
||||
Cow::Borrowed(VALIDATION_MARKER_KEY),
|
||||
serde_json::json!(ORDER_VALIDATION_MARKER),
|
||||
);
|
||||
let config = ExtractionConfig {
|
||||
postprocessor: Some(crate::core::config::PostProcessorConfig {
|
||||
enabled: false,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
assert_eq!(processed.content, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_pipeline_with_quality_processing() {
    // Other tests may have emptied the registry; make sure the processor is present.
    ensure_quality_processor();

    let config = ExtractionConfig {
        enable_quality_processing: true,
        ..Default::default()
    };
    let doc = make_doc("This is a test document with some meaningful content.", "text/plain");

    let processed = run_pipeline(doc, &config)
        .await
        .expect("pipeline should succeed with quality processing enabled");
    assert!(processed.quality_score.is_some());
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_pipeline_without_quality_processing() {
|
||||
let doc = make_doc("test", "text/plain");
|
||||
let config = ExtractionConfig {
|
||||
enable_quality_processing: false,
|
||||
postprocessor: Some(crate::core::config::PostProcessorConfig {
|
||||
enabled: false,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
assert!(processed.quality_score.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
#[serial]
#[cfg(feature = "chunking")]
async fn test_pipeline_with_chunking() {
    // 100 repetitions is far more text than a single 500-character chunk can hold.
    let long_text = "This is a long text that should be chunked. ".repeat(100);
    let chunking = crate::ChunkingConfig {
        max_characters: 500,
        overlap: 50,
        trim: true,
        chunker_type: crate::ChunkerType::Text,
        ..Default::default()
    };
    let config = ExtractionConfig {
        chunking: Some(chunking),
        ..Default::default()
    };

    let processed = run_pipeline(make_doc(&long_text, "text/plain"), &config)
        .await
        .expect("pipeline should succeed with chunking enabled");

    let chunk_count = processed
        .chunks
        .as_ref()
        .map(|chunks| chunks.len())
        .expect("chunks should be present");
    assert!(chunk_count > 1);
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_pipeline_without_chunking() {
|
||||
let doc = make_doc("test", "text/plain");
|
||||
let config = ExtractionConfig {
|
||||
chunking: None,
|
||||
postprocessor: Some(crate::core::config::PostProcessorConfig {
|
||||
enabled: false,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
assert!(processed.chunks.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_pipeline_preserves_metadata() {
|
||||
use ahash::AHashMap;
|
||||
let mut additional = AHashMap::new();
|
||||
additional.insert(Cow::Borrowed("source"), serde_json::json!("test"));
|
||||
additional.insert(Cow::Borrowed("page"), serde_json::json!(1));
|
||||
|
||||
let doc = make_doc_with_metadata(
|
||||
"test",
|
||||
"text/plain",
|
||||
Metadata {
|
||||
additional,
|
||||
..Default::default()
|
||||
},
|
||||
);
|
||||
let config = ExtractionConfig {
|
||||
postprocessor: Some(crate::core::config::PostProcessorConfig {
|
||||
enabled: false,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
assert_eq!(
|
||||
processed.metadata.additional.get("source").unwrap(),
|
||||
&serde_json::json!("test")
|
||||
);
|
||||
assert_eq!(
|
||||
processed.metadata.additional.get("page").unwrap(),
|
||||
&serde_json::json!(1)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_pipeline_preserves_tables() {
|
||||
use crate::types::Table;
|
||||
|
||||
let table = Table {
|
||||
cells: vec![vec!["A".to_string(), "B".to_string()]],
|
||||
markdown: "| A | B |".to_string(),
|
||||
page_number: 0,
|
||||
bounding_box: None,
|
||||
};
|
||||
|
||||
let mut doc = make_doc("test", "text/plain");
|
||||
doc.tables.push(table);
|
||||
let config = ExtractionConfig {
|
||||
postprocessor: Some(crate::core::config::PostProcessorConfig {
|
||||
enabled: false,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
assert_eq!(processed.tables.len(), 1);
|
||||
assert_eq!(processed.tables[0].cells.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_pipeline_empty_content() {
|
||||
{
|
||||
let registry = crate::plugins::registry::get_post_processor_registry();
|
||||
registry.write().shutdown_all().unwrap();
|
||||
}
|
||||
{
|
||||
let registry = crate::plugins::registry::get_validator_registry();
|
||||
registry.write().shutdown_all().unwrap();
|
||||
}
|
||||
|
||||
let doc = make_doc("", "text/plain");
|
||||
let config = ExtractionConfig::default();
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
assert_eq!(processed.content, "");
|
||||
}
|
||||
|
||||
/// Runs the pipeline with quality processing and chunking enabled together.
///
/// The quality-score assertion is only compiled when the `quality` feature is on.
#[tokio::test]
#[serial]
#[cfg(feature = "chunking")]
async fn test_pipeline_with_all_features() {
    #[cfg(feature = "quality")]
    ensure_quality_processor();

    let body = "This is a comprehensive test document. ".repeat(50);
    let input = make_doc(&body, "text/plain");

    let chunking = crate::ChunkingConfig {
        max_characters: 500,
        overlap: 50,
        trim: true,
        chunker_type: crate::ChunkerType::Text,
        ..Default::default()
    };
    let settings = ExtractionConfig {
        enable_quality_processing: true,
        chunking: Some(chunking),
        ..Default::default()
    };

    let outcome = run_pipeline(input, &settings).await.unwrap();

    #[cfg(feature = "quality")]
    assert!(outcome.quality_score.is_some());
    assert!(outcome.chunks.is_some());
}
|
||||
|
||||
/// With a keyword config set, the pipeline must populate `extracted_keywords` with at least one
/// keyword that has non-empty text and a positive score.
#[tokio::test]
#[serial]
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
async fn test_pipeline_with_keyword_extraction() {
    let validators = crate::plugins::registry::get_validator_registry();
    validators.write().shutdown_all().unwrap();
    let post_processors = crate::plugins::registry::get_post_processor_registry();
    post_processors.write().shutdown_all().unwrap();

    // Register keyword processor directly (bypasses Lazy which only runs once per process)
    let _ = crate::keywords::register_keyword_processor();
    clear_processor_cache().unwrap();

    let input = make_doc(
        r#"
Machine learning is a branch of artificial intelligence that focuses on
building systems that can learn from data. Deep learning is a subset of
machine learning that uses neural networks with multiple layers.
Natural language processing enables computers to understand human language.
"#,
        "text/plain",
    );

    // YAKE is preferred when both keyword features are compiled in.
    #[cfg(feature = "keywords-yake")]
    let keyword_config = crate::keywords::KeywordConfig::yake();

    #[cfg(all(feature = "keywords-rake", not(feature = "keywords-yake")))]
    let keyword_config = crate::keywords::KeywordConfig::rake();

    let settings = ExtractionConfig {
        keywords: Some(keyword_config),
        ..Default::default()
    };

    let outcome = run_pipeline(input, &settings).await.unwrap();

    let found = outcome
        .extracted_keywords
        .as_ref()
        .expect("Should have extracted keywords");
    assert!(!found.is_empty(), "Should have extracted keywords");

    let leading = &found[0];
    assert!(!leading.text.is_empty());
    assert!(leading.score > 0.0);
}
|
||||
|
||||
/// Without a keyword config the pipeline must not produce keywords.
///
/// The positive keyword test in this module reads keywords from `extracted_keywords`, so that
/// field is checked here in addition to the `keywords` metadata key; checking the metadata key
/// alone would pass even if keywords were extracted.
#[tokio::test]
#[serial]
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
async fn test_pipeline_without_keyword_config() {
    let doc = make_doc("Machine learning and artificial intelligence.", "text/plain");

    let config = ExtractionConfig {
        keywords: None,
        ..Default::default()
    };

    let processed = run_pipeline(doc, &config).await.unwrap();

    assert!(!processed.metadata.additional.contains_key("keywords"));
    // Accept either `None` or an empty list: both mean "no keywords were extracted".
    assert!(
        processed.extracted_keywords.as_deref().unwrap_or_default().is_empty(),
        "No keywords should be extracted when keyword config is None"
    );
}
|
||||
|
||||
/// Runs keyword extraction over a two-word document and checks that no `keywords` entry is
/// written to the metadata.
#[tokio::test]
#[serial]
#[cfg(any(feature = "keywords-yake", feature = "keywords-rake"))]
async fn test_pipeline_keyword_extraction_short_content() {
    let validators = crate::plugins::registry::get_validator_registry();
    validators.write().shutdown_all().unwrap();
    let post_processors = crate::plugins::registry::get_post_processor_registry();
    post_processors.write().shutdown_all().unwrap();

    let input = make_doc("Short text", "text/plain");

    // YAKE is preferred when both keyword features are compiled in.
    #[cfg(feature = "keywords-yake")]
    let keyword_config = crate::keywords::KeywordConfig::yake();

    #[cfg(all(feature = "keywords-rake", not(feature = "keywords-yake")))]
    let keyword_config = crate::keywords::KeywordConfig::rake();

    let settings = ExtractionConfig {
        keywords: Some(keyword_config),
        ..Default::default()
    };

    let outcome = run_pipeline(input, &settings).await.unwrap();

    // NOTE(review): test_pipeline_with_keyword_extraction reads keywords from
    // `extracted_keywords`, not from metadata — confirm this key is still the right place to look.
    assert!(!outcome.metadata.additional.contains_key("keywords"));
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_postprocessor_runs_before_validator() {
|
||||
use crate::plugins::{Plugin, PostProcessor, ProcessingStage, Validator};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
struct TestPostProcessor;
|
||||
impl Plugin for TestPostProcessor {
|
||||
fn name(&self) -> &str {
|
||||
"test-processor"
|
||||
}
|
||||
fn version(&self) -> String {
|
||||
"1.0.0".to_string()
|
||||
}
|
||||
fn initialize(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
fn shutdown(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl PostProcessor for TestPostProcessor {
|
||||
async fn process(&self, result: &mut ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
|
||||
result
|
||||
.metadata
|
||||
.additional
|
||||
.insert(Cow::Borrowed("processed"), serde_json::json!(true));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn processing_stage(&self) -> ProcessingStage {
|
||||
ProcessingStage::Middle
|
||||
}
|
||||
}
|
||||
|
||||
struct TestValidator;
|
||||
impl Plugin for TestValidator {
|
||||
fn name(&self) -> &str {
|
||||
"test-validator"
|
||||
}
|
||||
fn version(&self) -> String {
|
||||
"1.0.0".to_string()
|
||||
}
|
||||
fn initialize(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
fn shutdown(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Validator for TestValidator {
|
||||
async fn validate(&self, result: &ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
|
||||
let should_validate = result
|
||||
.metadata
|
||||
.additional
|
||||
.get(VALIDATION_MARKER_KEY)
|
||||
.and_then(|v| v.as_str())
|
||||
== Some(POSTPROCESSOR_VALIDATION_MARKER);
|
||||
|
||||
if !should_validate {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let processed = result
|
||||
.metadata
|
||||
.additional
|
||||
.get("processed")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
if !processed {
|
||||
return Err(crate::KreuzbergError::Validation {
|
||||
message: "Post-processor did not run before validator".to_string(),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let pp_registry = crate::plugins::registry::get_post_processor_registry();
|
||||
let val_registry = crate::plugins::registry::get_validator_registry();
|
||||
|
||||
clear_processor_cache().unwrap();
|
||||
pp_registry.write().shutdown_all().unwrap();
|
||||
val_registry.write().shutdown_all().unwrap();
|
||||
clear_processor_cache().unwrap();
|
||||
|
||||
{
|
||||
let mut registry = pp_registry.write();
|
||||
registry.register(Arc::new(TestPostProcessor)).unwrap();
|
||||
}
|
||||
|
||||
{
|
||||
let mut registry = val_registry.write();
|
||||
registry.register(Arc::new(TestValidator)).unwrap();
|
||||
}
|
||||
|
||||
clear_processor_cache().unwrap();
|
||||
|
||||
let mut doc = make_doc("test", "text/plain");
|
||||
doc.metadata.additional.insert(
|
||||
Cow::Borrowed(VALIDATION_MARKER_KEY),
|
||||
serde_json::json!(POSTPROCESSOR_VALIDATION_MARKER),
|
||||
);
|
||||
|
||||
let config = ExtractionConfig {
|
||||
postprocessor: Some(crate::core::config::PostProcessorConfig {
|
||||
enabled: true,
|
||||
enabled_set: None,
|
||||
disabled_set: None,
|
||||
enabled_processors: None,
|
||||
disabled_processors: None,
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await;
|
||||
|
||||
pp_registry.write().shutdown_all().unwrap();
|
||||
val_registry.write().shutdown_all().unwrap();
|
||||
|
||||
assert!(processed.is_ok(), "Validator should have seen post-processor metadata");
|
||||
let processed = processed.unwrap();
|
||||
assert_eq!(
|
||||
processed.metadata.additional.get("processed"),
|
||||
Some(&serde_json::json!(true)),
|
||||
"Post-processor metadata should be present"
|
||||
);
|
||||
}
|
||||
|
||||
/// Verifies that quality processing has already produced a score by the time validators run.
///
/// The validator only acts on documents tagged with `QUALITY_VALIDATION_MARKER` and fails when
/// such a document has no `quality_score`.
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_quality_processing_runs_before_validator() {
    use crate::plugins::{Plugin, Validator};
    use async_trait::async_trait;
    use std::sync::Arc;

    ensure_quality_processor();

    struct QualityValidator;
    impl Plugin for QualityValidator {
        fn name(&self) -> &str {
            "quality-validator"
        }
        fn version(&self) -> String {
            "1.0.0".to_string()
        }
        fn initialize(&self) -> Result<()> {
            Ok(())
        }
        fn shutdown(&self) -> Result<()> {
            Ok(())
        }
    }

    #[async_trait]
    impl Validator for QualityValidator {
        async fn validate(&self, result: &ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
            let marker = result
                .metadata
                .additional
                .get(VALIDATION_MARKER_KEY)
                .and_then(|v| v.as_str());

            // Untagged documents pass; tagged ones must already have a score.
            if marker == Some(QUALITY_VALIDATION_MARKER) && result.quality_score.is_none() {
                return Err(crate::KreuzbergError::Validation {
                    message: "Quality processing did not run before validator".to_string(),
                    source: None,
                });
            }
            Ok(())
        }
    }

    let validators = crate::plugins::registry::get_validator_registry();
    validators.write().register(Arc::new(QualityValidator)).unwrap();

    let mut input = make_doc("This is meaningful test content for quality scoring.", "text/plain");
    input.metadata.additional.insert(
        Cow::Borrowed(VALIDATION_MARKER_KEY),
        serde_json::json!(QUALITY_VALIDATION_MARKER),
    );

    let settings = ExtractionConfig {
        enable_quality_processing: true,
        ..Default::default()
    };

    let outcome = run_pipeline(input, &settings).await;

    // Unregister before asserting so a failure does not leave the validator behind.
    validators.write().remove("quality-validator").unwrap();

    assert!(outcome.is_ok(), "Validator should have seen quality_score");
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_multiple_postprocessors_run_before_validator() {
|
||||
use crate::plugins::{Plugin, PostProcessor, ProcessingStage, Validator};
|
||||
use async_trait::async_trait;
|
||||
use std::sync::Arc;
|
||||
|
||||
struct EarlyProcessor;
|
||||
impl Plugin for EarlyProcessor {
|
||||
fn name(&self) -> &str {
|
||||
"early-proc"
|
||||
}
|
||||
fn version(&self) -> String {
|
||||
"1.0.0".to_string()
|
||||
}
|
||||
fn initialize(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
fn shutdown(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl PostProcessor for EarlyProcessor {
|
||||
async fn process(&self, result: &mut ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
|
||||
let mut order = result
|
||||
.metadata
|
||||
.additional
|
||||
.get("execution_order")
|
||||
.and_then(|v| v.as_array())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
order.push(serde_json::json!("early"));
|
||||
result
|
||||
.metadata
|
||||
.additional
|
||||
.insert(Cow::Borrowed("execution_order"), serde_json::json!(order));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn processing_stage(&self) -> ProcessingStage {
|
||||
ProcessingStage::Early
|
||||
}
|
||||
}
|
||||
|
||||
struct LateProcessor;
|
||||
impl Plugin for LateProcessor {
|
||||
fn name(&self) -> &str {
|
||||
"late-proc"
|
||||
}
|
||||
fn version(&self) -> String {
|
||||
"1.0.0".to_string()
|
||||
}
|
||||
fn initialize(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
fn shutdown(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl PostProcessor for LateProcessor {
|
||||
async fn process(&self, result: &mut ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
|
||||
let mut order = result
|
||||
.metadata
|
||||
.additional
|
||||
.get("execution_order")
|
||||
.and_then(|v| v.as_array())
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
order.push(serde_json::json!("late"));
|
||||
result
|
||||
.metadata
|
||||
.additional
|
||||
.insert(Cow::Borrowed("execution_order"), serde_json::json!(order));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn processing_stage(&self) -> ProcessingStage {
|
||||
ProcessingStage::Late
|
||||
}
|
||||
}
|
||||
|
||||
struct OrderValidator;
|
||||
impl Plugin for OrderValidator {
|
||||
fn name(&self) -> &str {
|
||||
"order-validator"
|
||||
}
|
||||
fn version(&self) -> String {
|
||||
"1.0.0".to_string()
|
||||
}
|
||||
fn initialize(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
fn shutdown(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Validator for OrderValidator {
|
||||
async fn validate(&self, result: &ExtractionResult, _config: &ExtractionConfig) -> Result<()> {
|
||||
let should_validate = result
|
||||
.metadata
|
||||
.additional
|
||||
.get(VALIDATION_MARKER_KEY)
|
||||
.and_then(|v| v.as_str())
|
||||
== Some(ORDER_VALIDATION_MARKER);
|
||||
|
||||
if !should_validate {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let order = result
|
||||
.metadata
|
||||
.additional
|
||||
.get("execution_order")
|
||||
.and_then(|v| v.as_array())
|
||||
.ok_or_else(|| crate::KreuzbergError::Validation {
|
||||
message: "No execution order found".to_string(),
|
||||
source: None,
|
||||
})?;
|
||||
|
||||
if order.len() != 2 {
|
||||
return Err(crate::KreuzbergError::Validation {
|
||||
message: format!("Expected 2 processors to run, got {}", order.len()),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
|
||||
if order[0] != "early" || order[1] != "late" {
|
||||
return Err(crate::KreuzbergError::Validation {
|
||||
message: format!("Wrong execution order: {:?}", order),
|
||||
source: None,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
let pp_registry = crate::plugins::registry::get_post_processor_registry();
|
||||
let val_registry = crate::plugins::registry::get_validator_registry();
|
||||
|
||||
pp_registry.write().shutdown_all().unwrap();
|
||||
val_registry.write().shutdown_all().unwrap();
|
||||
clear_processor_cache().unwrap();
|
||||
|
||||
{
|
||||
let mut registry = pp_registry.write();
|
||||
registry.register(Arc::new(EarlyProcessor)).unwrap();
|
||||
registry.register(Arc::new(LateProcessor)).unwrap();
|
||||
}
|
||||
|
||||
{
|
||||
let mut registry = val_registry.write();
|
||||
registry.register(Arc::new(OrderValidator)).unwrap();
|
||||
}
|
||||
|
||||
// Clear the cache after registering new processors so it rebuilds with the test processors
|
||||
clear_processor_cache().unwrap();
|
||||
|
||||
let doc = make_doc("test", "text/plain");
|
||||
|
||||
let config = ExtractionConfig::default();
|
||||
|
||||
let processed = run_pipeline(doc, &config).await;
|
||||
|
||||
pp_registry.write().shutdown_all().unwrap();
|
||||
val_registry.write().shutdown_all().unwrap();
|
||||
clear_processor_cache().unwrap();
|
||||
|
||||
assert!(processed.is_ok(), "All processors should run before validator");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_run_pipeline_with_output_format_plain() {
|
||||
let doc = make_doc("test content", "text/plain");
|
||||
|
||||
let config = crate::core::config::ExtractionConfig {
|
||||
output_format: OutputFormat::Plain,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
assert_eq!(processed.content, "test content");
|
||||
assert_eq!(processed.metadata.output_format, Some("plain".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_run_pipeline_with_output_format_djot() {
|
||||
let doc = make_doc("test content", "text/djot");
|
||||
|
||||
let config = crate::core::config::ExtractionConfig {
|
||||
output_format: OutputFormat::Djot,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
// The content should still be present
|
||||
assert!(!processed.content.is_empty());
|
||||
assert_eq!(processed.metadata.output_format, Some("djot".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_run_pipeline_with_output_format_html() {
|
||||
let doc = make_doc("test content", "text/plain");
|
||||
|
||||
let config = crate::core::config::ExtractionConfig {
|
||||
output_format: OutputFormat::Html,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
// HTML renderer produces semantic tags from InternalDocument
|
||||
assert!(processed.content.contains("test content"));
|
||||
assert_eq!(processed.metadata.output_format, Some("html".to_string()));
|
||||
}
|
||||
|
||||
/// NFC normalization must turn a decomposed character into its precomposed form:
/// "e\u{0301}" (e + combining acute accent) becomes "\u{00e9}".
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_nfc_normalization_decomposes_to_composed() {
    // "café" spelled with a decomposed é
    let input = make_doc("caf\u{0065}\u{0301}", "text/plain");

    let postprocessing_off = crate::core::config::PostProcessorConfig {
        enabled: false,
        ..Default::default()
    };
    let settings = ExtractionConfig {
        postprocessor: Some(postprocessing_off),
        ..Default::default()
    };

    let outcome = run_pipeline(input, &settings).await.unwrap();

    // Composed é, and no combining accent left behind.
    assert_eq!(outcome.content, "caf\u{00e9}");
    assert!(!outcome.content.contains('\u{0301}'));
}
|
||||
|
||||
/// NFC normalization of already-normalized ASCII text must leave it unchanged.
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_nfc_normalization_idempotent_on_ascii() {
    let input = make_doc("Hello, world! 123", "text/plain");

    let postprocessing_off = crate::core::config::PostProcessorConfig {
        enabled: false,
        ..Default::default()
    };
    let settings = ExtractionConfig {
        postprocessor: Some(postprocessing_off),
        ..Default::default()
    };

    let outcome = run_pipeline(input, &settings).await.unwrap();
    assert_eq!(outcome.content, "Hello, world! 123");
}
|
||||
|
||||
/// NFC normalization must reach per-page content as well as the top-level content.
#[tokio::test]
#[serial]
#[cfg(feature = "quality")]
async fn test_nfc_normalization_applies_to_page_content() {
    // One paragraph on page 1 whose text uses decomposed accents.
    let decomposed = InternalElement::text(ElementKind::Paragraph, "re\u{0301}sume\u{0301}", 0).with_page(1);
    let mut input = InternalDocument::new("plain");
    input.mime_type = "text/plain".to_string();
    input.push_element(decomposed);

    let postprocessing_off = crate::core::config::PostProcessorConfig {
        enabled: false,
        ..Default::default()
    };
    let settings = ExtractionConfig {
        postprocessor: Some(postprocessing_off),
        ..Default::default()
    };

    let outcome = run_pipeline(input, &settings).await.unwrap();

    // Content derived from page element
    assert!(outcome.content.contains("r\u{00e9}sum\u{00e9}"));
    let page_list = outcome.pages.unwrap();
    assert_eq!(page_list[0].content, "r\u{00e9}sum\u{00e9}");
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_run_pipeline_applies_output_format_last() {
|
||||
// This test verifies that output format is applied after all other processing
|
||||
let doc = make_doc("test", "text/plain");
|
||||
|
||||
let config = crate::core::config::ExtractionConfig {
|
||||
output_format: OutputFormat::Djot,
|
||||
// Disable other processing to ensure pipeline runs cleanly
|
||||
enable_quality_processing: false,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
let processed = run_pipeline(doc, &config).await.unwrap();
|
||||
// The result should have gone through the pipeline successfully
|
||||
assert_eq!(processed.metadata.output_format, Some("djot".to_string()));
|
||||
}
|
||||
|
||||
/// Extracts a PDF fixture with chunking enabled and checks that at least one chunk carries a
/// `first_page` value.
///
/// Returns early without asserting anything when the fixture file does not exist.
#[tokio::test]
#[serial]
#[cfg(all(feature = "pdf", feature = "chunking"))]
async fn test_chunking_populates_page_numbers_for_pdf() {
    use crate::core::config::ChunkingConfig;

    let fixture =
        std::path::Path::new(env!("CARGO_MANIFEST_DIR")).join("../../test_documents/pdf/issue-636-chunk-pages.pdf");

    // Skip if test document not available
    if !fixture.exists() {
        return;
    }

    let raw_pdf = std::fs::read(&fixture).unwrap();

    // Configure chunking WITHOUT explicit pages config (the default user scenario)
    let chunking = ChunkingConfig {
        max_characters: 500,
        ..Default::default()
    };
    let settings = ExtractionConfig {
        chunking: Some(chunking),
        ..Default::default()
    };

    let extracted = crate::core::extractor::extract_bytes(&raw_pdf, "application/pdf", &settings)
        .await
        .unwrap();

    // Chunks should exist
    assert!(extracted.chunks.is_some(), "Chunks should be produced");
    let produced = extracted.chunks.as_ref().unwrap();
    assert!(!produced.is_empty(), "Should have at least one chunk");

    // At least some chunks should have page numbers
    let any_paged = produced.iter().any(|c| c.metadata.first_page.is_some());
    assert!(
        any_paged,
        "At least some chunks should have page numbers, but none do. Total chunks: {}",
        produced.len()
    );
}
|
||||
|
||||
/// With `append_ocr_text` set, `append_embedded_image_ocr_text` must add the image's OCR text
/// as a fourth element at index 2, and that text must appear in the rendered markdown.
#[test]
fn test_append_ocr_text_for_pptx_images() {
    use crate::types::ExtractedImage;
    use crate::types::internal::{ElementKind, InternalDocument, InternalElement};
    use std::borrow::Cow;

    let mut slides = InternalDocument::new("pptx");
    slides.append_ocr_text = true;

    // NOTE(review): the middle paragraph's text is an empty literal in this copy of the file;
    // presumably it stands in for the embedded image — confirm against the repository.
    for paragraph_text in ["Before image.", "", "After image."] {
        slides
            .elements
            .push(InternalElement::text(ElementKind::Paragraph, paragraph_text, 0));
    }

    let ocr_output = crate::types::ExtractionResult {
        content: "OCR text here".to_string(),
        mime_type: Cow::Borrowed("text/plain"),
        ..Default::default()
    };
    slides.images.push(ExtractedImage {
        data: bytes::Bytes::new(),
        format: Cow::Borrowed("jpeg"),
        image_index: 0,
        page_number: Some(1),
        width: Some(100),
        height: Some(100),
        colorspace: None,
        bits_per_component: None,
        is_mask: false,
        description: None,
        ocr_result: Some(Box::new(ocr_output)),
        bounding_box: None,
        source_path: None,
        image_kind: None,
        kind_confidence: None,
        cluster_id: None,
    });

    super::append_embedded_image_ocr_text(&mut slides);

    assert_eq!(
        slides.elements.len(),
        4,
        "should have 4 elements (original 3 + 1 OCR paragraph)"
    );
    assert_eq!(slides.elements[2].text, "OCR text here");

    let markdown = crate::rendering::render_markdown(&slides);
    assert!(markdown.contains("OCR text here"));
}
|
||||
|
||||
#[tokio::test]
|
||||
#[serial]
|
||||
async fn test_pdf_run_fallback_not_suppressed_without_images_config() {
|
||||
// When config.images is None, run_ocr_on_images must default to false so
|
||||
// the PDF document-level OCR fallback is NOT silently suppressed for
|
||||
// existing callers that never configured ImageExtractionConfig.
|
||||
use crate::core::config::ImageExtractionConfig;
|
||||
|
||||
let default_no_images = crate::core::config::ExtractionConfig::default();
|
||||
assert!(
|
||||
default_no_images.images.is_none(),
|
||||
"baseline: default config has no images section"
|
||||
);
|
||||
|
||||
let skip_fallback = default_no_images
|
||||
.images
|
||||
.as_ref()
|
||||
.map(|i| i.run_ocr_on_images)
|
||||
.unwrap_or(false);
|
||||
assert!(
|
||||
!skip_fallback,
|
||||
"RunFallback must NOT be suppressed when config.images is None"
|
||||
);
|
||||
|
||||
let with_images_opted_in = crate::core::config::ExtractionConfig {
|
||||
images: Some(ImageExtractionConfig {
|
||||
run_ocr_on_images: true,
|
||||
..Default::default()
|
||||
}),
|
||||
..Default::default()
|
||||
};
|
||||
let skip_fallback_opted_in = with_images_opted_in
|
||||
.images
|
||||
.as_ref()
|
||||
.map(|i| i.run_ocr_on_images)
|
||||
.unwrap_or(false);
|
||||
assert!(
|
||||
skip_fallback_opted_in,
|
||||
"RunFallback must be suppressed when images.run_ocr_on_images=true"
|
||||
);
|
||||
}
|
||||
76
crates/kreuzberg/src/core/server_config/env.rs
Normal file
76
crates/kreuzberg/src/core/server_config/env.rs
Normal file
@@ -0,0 +1,76 @@
|
||||
//! Environment variable overrides for server configuration.
|
||||
//!
|
||||
//! This module provides functionality to override server configuration values
|
||||
//! using environment variables. All settings can be overridden at runtime.
|
||||
|
||||
use crate::{KreuzbergError, Result};
|
||||
|
||||
/// Apply environment variable overrides to a ServerConfig.
|
||||
///
|
||||
/// Reads the following environment variables and overrides config values if set:
|
||||
///
|
||||
/// - `KREUZBERG_HOST` - Server host address
|
||||
/// - `KREUZBERG_PORT` - Server port number (parsed as u16)
|
||||
/// - `KREUZBERG_CORS_ORIGINS` - Comma-separated list of allowed origins
|
||||
/// - `KREUZBERG_MAX_REQUEST_BODY_BYTES` - Max request body size in bytes
|
||||
/// - `KREUZBERG_MAX_MULTIPART_FIELD_BYTES` - Max multipart field size in bytes
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if:
|
||||
/// - `KREUZBERG_PORT` cannot be parsed as u16
|
||||
/// - `KREUZBERG_MAX_REQUEST_BODY_BYTES` cannot be parsed as usize
|
||||
/// - `KREUZBERG_MAX_MULTIPART_FIELD_BYTES` cannot be parsed as usize
|
||||
pub(crate) fn apply_env_overrides(
|
||||
host: &mut String,
|
||||
port: &mut u16,
|
||||
cors_origins: &mut Vec<String>,
|
||||
max_request_body_bytes: &mut usize,
|
||||
max_multipart_field_bytes: &mut usize,
|
||||
) -> Result<()> {
|
||||
// Host override
|
||||
if let Ok(env_host) = std::env::var("KREUZBERG_HOST") {
|
||||
*host = env_host;
|
||||
}
|
||||
|
||||
// Port override
|
||||
if let Ok(port_str) = std::env::var("KREUZBERG_PORT") {
|
||||
*port = port_str.parse::<u16>().map_err(|e| {
|
||||
KreuzbergError::validation(format!(
|
||||
"KREUZBERG_PORT must be a valid u16 number, got '{}': {}",
|
||||
port_str, e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
// CORS origins override (comma-separated)
|
||||
if let Ok(origins_str) = std::env::var("KREUZBERG_CORS_ORIGINS") {
|
||||
*cors_origins = origins_str
|
||||
.split(',')
|
||||
.map(|s| s.trim().to_string())
|
||||
.filter(|s| !s.is_empty())
|
||||
.collect();
|
||||
}
|
||||
|
||||
// Max request body bytes override
|
||||
if let Ok(bytes_str) = std::env::var("KREUZBERG_MAX_REQUEST_BODY_BYTES") {
|
||||
*max_request_body_bytes = bytes_str.parse::<usize>().map_err(|e| {
|
||||
KreuzbergError::validation(format!(
|
||||
"KREUZBERG_MAX_REQUEST_BODY_BYTES must be a valid usize, got '{}': {}",
|
||||
bytes_str, e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
// Max multipart field bytes override
|
||||
if let Ok(bytes_str) = std::env::var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES") {
|
||||
*max_multipart_field_bytes = bytes_str.parse::<usize>().map_err(|e| {
|
||||
KreuzbergError::validation(format!(
|
||||
"KREUZBERG_MAX_MULTIPART_FIELD_BYTES must be a valid usize, got '{}': {}",
|
||||
bytes_str, e
|
||||
))
|
||||
})?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
193
crates/kreuzberg/src/core/server_config/loader.rs
Normal file
193
crates/kreuzberg/src/core/server_config/loader.rs
Normal file
@@ -0,0 +1,193 @@
|
||||
//! File loading logic for server configuration.
|
||||
//!
|
||||
//! This module provides functionality to load server configuration from various
|
||||
//! file formats (TOML, YAML, JSON) with support for both flat and nested formats.
|
||||
|
||||
use crate::{KreuzbergError, Result};
|
||||
use serde::Deserialize;
|
||||
use std::path::Path;
|
||||
|
||||
use super::ServerConfig;
|
||||
|
||||
/// Load server configuration from a file.
|
||||
///
|
||||
/// Automatically detects the file format based on extension:
|
||||
/// - `.toml` - TOML format
|
||||
/// - `.yaml` or `.yml` - YAML format
|
||||
/// - `.json` - JSON format
|
||||
///
|
||||
/// This function handles two config file formats:
|
||||
/// 1. Flat format: Server config at root level
|
||||
/// 2. Nested format: Server config under `[server]` section (combined with ExtractionConfig)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the configuration file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if:
|
||||
/// - File doesn't exist or cannot be read
|
||||
/// - File extension is not recognized
|
||||
/// - File content is invalid for the detected format
|
||||
pub(crate) fn from_file(path: impl AsRef<Path>) -> Result<ServerConfig> {
|
||||
let path = path.as_ref();
|
||||
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
|
||||
|
||||
let extension = path.extension().and_then(|ext| ext.to_str()).ok_or_else(|| {
|
||||
KreuzbergError::validation(format!(
|
||||
"Cannot determine file format: no extension found in {}",
|
||||
path.display()
|
||||
))
|
||||
})?;
|
||||
|
||||
let config = match extension.to_lowercase().as_str() {
|
||||
"toml" => from_toml_str(&content, path)?,
|
||||
"yaml" | "yml" => from_yaml_str(&content, path)?,
|
||||
"json" => from_json_str(&content, path)?,
|
||||
_ => {
|
||||
return Err(KreuzbergError::validation(format!(
|
||||
"Unsupported config file format: .{}. Supported formats: .toml, .yaml, .yml, .json",
|
||||
extension
|
||||
)));
|
||||
}
|
||||
};
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Load server configuration from a TOML file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the TOML file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid TOML.
|
||||
pub(crate) fn from_toml_file(path: impl AsRef<Path>) -> Result<ServerConfig> {
|
||||
let path = path.as_ref();
|
||||
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
|
||||
|
||||
let config: ServerConfig = toml::from_str(&content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid TOML in {}: {}", path.display(), e)))?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Load server configuration from a YAML file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the YAML file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid YAML.
|
||||
pub(crate) fn from_yaml_file(path: impl AsRef<Path>) -> Result<ServerConfig> {
|
||||
let path = path.as_ref();
|
||||
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
|
||||
|
||||
let config: ServerConfig = serde_yaml_ng::from_str(&content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid YAML in {}: {}", path.display(), e)))?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
/// Load server configuration from a JSON file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the JSON file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid JSON.
|
||||
pub(crate) fn from_json_file(path: impl AsRef<Path>) -> Result<ServerConfig> {
|
||||
let path = path.as_ref();
|
||||
|
||||
let content = std::fs::read_to_string(path)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Failed to read config file {}: {}", path.display(), e)))?;
|
||||
|
||||
let config: ServerConfig = serde_json::from_str(&content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid JSON in {}: {}", path.display(), e)))?;
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
// Helper functions for parsing different formats
|
||||
|
||||
fn from_toml_str(content: &str, path: &Path) -> Result<ServerConfig> {
|
||||
// Try nested format first (with [server] section)
|
||||
#[derive(Deserialize)]
|
||||
struct RootConfig {
|
||||
#[serde(default)]
|
||||
server: Option<ServerConfig>,
|
||||
}
|
||||
|
||||
if let Ok(root) = toml::from_str::<RootConfig>(content) {
|
||||
if let Some(server) = root.server {
|
||||
return Ok(server);
|
||||
} else {
|
||||
// No [server] section, try flat format
|
||||
return toml::from_str::<ServerConfig>(content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid TOML in {}: {}", path.display(), e)));
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to flat format
|
||||
toml::from_str::<ServerConfig>(content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid TOML in {}: {}", path.display(), e)))
|
||||
}
|
||||
|
||||
fn from_yaml_str(content: &str, path: &Path) -> Result<ServerConfig> {
|
||||
// Try nested format first (with server: section)
|
||||
#[derive(Deserialize)]
|
||||
struct RootConfig {
|
||||
#[serde(default)]
|
||||
server: Option<ServerConfig>,
|
||||
}
|
||||
|
||||
if let Ok(root) = serde_yaml_ng::from_str::<RootConfig>(content) {
|
||||
if let Some(server) = root.server {
|
||||
return Ok(server);
|
||||
} else {
|
||||
// No server section, try flat format
|
||||
return serde_yaml_ng::from_str::<ServerConfig>(content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid YAML in {}: {}", path.display(), e)));
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to flat format
|
||||
serde_yaml_ng::from_str::<ServerConfig>(content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid YAML in {}: {}", path.display(), e)))
|
||||
}
|
||||
|
||||
fn from_json_str(content: &str, path: &Path) -> Result<ServerConfig> {
|
||||
// Try nested format first (with "server" key)
|
||||
#[derive(Deserialize)]
|
||||
struct RootConfig {
|
||||
#[serde(default)]
|
||||
server: Option<ServerConfig>,
|
||||
}
|
||||
|
||||
if let Ok(root) = serde_json::from_str::<RootConfig>(content) {
|
||||
if let Some(server) = root.server {
|
||||
return Ok(server);
|
||||
} else {
|
||||
// No server key, try flat format
|
||||
return serde_json::from_str::<ServerConfig>(content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid JSON in {}: {}", path.display(), e)));
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to flat format
|
||||
serde_json::from_str::<ServerConfig>(content)
|
||||
.map_err(|e| KreuzbergError::validation(format!("Invalid JSON in {}: {}", path.display(), e)))
|
||||
}
|
||||
357
crates/kreuzberg/src/core/server_config/mod.rs
Normal file
357
crates/kreuzberg/src/core/server_config/mod.rs
Normal file
@@ -0,0 +1,357 @@
|
||||
//! Server configuration for the Kreuzberg API.
|
||||
//!
|
||||
//! This module provides the `ServerConfig` struct for managing API server settings
|
||||
//! including host, port, CORS, and upload size limits. Configuration can be loaded
|
||||
//! from TOML, YAML, or JSON files and can be overridden by environment variables.
|
||||
//!
|
||||
//! # Features
|
||||
//!
|
||||
//! - **Multi-format support**: Load configuration from TOML, YAML, or JSON files
|
||||
//! - **Environment overrides**: All settings can be overridden via environment variables
|
||||
//! - **Sensible defaults**: All fields have reasonable defaults matching current behavior
|
||||
//! - **Flexible CORS**: Support for all origins (default) or specific origin lists
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust,no_run
|
||||
//! use kreuzberg::core::ServerConfig;
|
||||
//!
|
||||
//! # fn example() -> kreuzberg::Result<()> {
|
||||
//! // Create with defaults
|
||||
//! let mut config = ServerConfig::default();
|
||||
//!
|
||||
//! // Or load from file
|
||||
//! let mut config = ServerConfig::from_file("kreuzberg.toml")?;
|
||||
//!
|
||||
//! // Apply environment variable overrides
|
||||
//! config.apply_env_overrides()?;
|
||||
//!
|
||||
//! # Ok(())
|
||||
//! # }
|
||||
//! ```
|
||||
|
||||
use crate::Result;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::Path;
|
||||
|
||||
mod env;
|
||||
mod loader;
|
||||
mod validation;
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests;
|
||||
|
||||
/// Host the API server binds to when none is configured (loopback only).
const DEFAULT_HOST: &str = "127.0.0.1";

/// Port the API server listens on when none is configured.
const DEFAULT_PORT: u16 = 8000;

/// Request body size limit used when none is configured: 100 MB, in bytes.
const DEFAULT_MAX_REQUEST_BODY_BYTES: usize = 100 * 1024 * 1024;

/// Multipart field size limit used when none is configured: 100 MB, in bytes.
const DEFAULT_MAX_MULTIPART_FIELD_BYTES: usize = 100 * 1024 * 1024;
|
||||
|
||||
/// API server configuration.
|
||||
///
|
||||
/// This struct holds all configuration options for the Kreuzberg API server,
|
||||
/// including host/port settings, CORS configuration, and upload limits.
|
||||
///
|
||||
/// # Defaults
|
||||
///
|
||||
/// - `host`: "127.0.0.1" (localhost only)
|
||||
/// - `port`: 8000
|
||||
/// - `cors_origins`: empty vector (allows all origins)
|
||||
/// - `max_request_body_bytes`: 104_857_600 (100 MB)
|
||||
/// - `max_multipart_field_bytes`: 104_857_600 (100 MB)
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(default)]
|
||||
pub struct ServerConfig {
|
||||
/// Server host address (e.g., "127.0.0.1", "0.0.0.0")
|
||||
#[serde(default = "default_host")]
|
||||
pub host: String,
|
||||
|
||||
/// Server port number
|
||||
#[serde(default = "default_port")]
|
||||
pub port: u16,
|
||||
|
||||
/// CORS allowed origins. Empty vector means allow all origins.
|
||||
///
|
||||
/// If this is an empty vector, the server will accept requests from any origin.
|
||||
/// If populated with specific origins (e.g., `"https://example.com"`), only
|
||||
/// those origins will be allowed.
|
||||
#[serde(default)]
|
||||
pub cors_origins: Vec<String>,
|
||||
|
||||
/// Maximum size of request body in bytes (default: 100 MB)
|
||||
#[serde(default = "default_max_request_body_bytes")]
|
||||
pub max_request_body_bytes: usize,
|
||||
|
||||
/// Maximum size of multipart fields in bytes (default: 100 MB)
|
||||
#[serde(default = "default_max_multipart_field_bytes")]
|
||||
pub max_multipart_field_bytes: usize,
|
||||
}
|
||||
|
||||
impl Default for ServerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
host: default_host(),
|
||||
port: default_port(),
|
||||
cors_origins: Vec::new(),
|
||||
max_request_body_bytes: default_max_request_body_bytes(),
|
||||
max_multipart_field_bytes: default_max_multipart_field_bytes(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Default value functions for serde
|
||||
fn default_host() -> String {
|
||||
DEFAULT_HOST.to_string()
|
||||
}
|
||||
|
||||
fn default_port() -> u16 {
|
||||
DEFAULT_PORT
|
||||
}
|
||||
|
||||
fn default_max_request_body_bytes() -> usize {
|
||||
DEFAULT_MAX_REQUEST_BODY_BYTES
|
||||
}
|
||||
|
||||
fn default_max_multipart_field_bytes() -> usize {
|
||||
DEFAULT_MAX_MULTIPART_FIELD_BYTES
|
||||
}
|
||||
|
||||
impl ServerConfig {
|
||||
/// Create a new `ServerConfig` with default values.
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
/// Get the server listen address (host:port).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::core::ServerConfig;
|
||||
///
|
||||
/// let config = ServerConfig::default();
|
||||
/// assert_eq!(config.listen_addr(), "127.0.0.1:8000");
|
||||
/// ```
|
||||
pub fn listen_addr(&self) -> String {
|
||||
format!("{}:{}", self.host, self.port)
|
||||
}
|
||||
|
||||
/// Check if CORS allows all origins.
|
||||
///
|
||||
/// Returns `true` if the `cors_origins` vector is empty, meaning all origins
|
||||
/// are allowed. Returns `false` if specific origins are configured.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::core::ServerConfig;
|
||||
///
|
||||
/// let mut config = ServerConfig::default();
|
||||
/// assert!(config.cors_allows_all());
|
||||
///
|
||||
/// config.cors_origins.push("https://example.com".to_string());
|
||||
/// assert!(!config.cors_allows_all());
|
||||
/// ```
|
||||
pub fn cors_allows_all(&self) -> bool {
|
||||
self.cors_origins.is_empty()
|
||||
}
|
||||
|
||||
/// Check if a given origin is allowed by CORS configuration.
|
||||
///
|
||||
/// Returns `true` if:
|
||||
/// - CORS allows all origins (empty origins list), or
|
||||
/// - The given origin is in the allowed origins list
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `origin` - The origin to check (e.g., "https://example.com")
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::core::ServerConfig;
|
||||
///
|
||||
/// let mut config = ServerConfig::default();
|
||||
/// assert!(config.is_origin_allowed("https://example.com"));
|
||||
///
|
||||
/// config.cors_origins.push("https://allowed.com".to_string());
|
||||
/// assert!(config.is_origin_allowed("https://allowed.com"));
|
||||
/// assert!(!config.is_origin_allowed("https://denied.com"));
|
||||
/// ```
|
||||
pub fn is_origin_allowed(&self, origin: &str) -> bool {
|
||||
self.cors_origins.is_empty() || self.cors_origins.contains(&origin.to_string())
|
||||
}
|
||||
|
||||
/// Get maximum request body size in megabytes (rounded up).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::core::ServerConfig;
|
||||
///
|
||||
/// let mut config = ServerConfig::default();
|
||||
/// assert_eq!(config.max_request_body_mb(), 100);
|
||||
/// ```
|
||||
pub fn max_request_body_mb(&self) -> usize {
|
||||
self.max_request_body_bytes.div_ceil(1_048_576)
|
||||
}
|
||||
|
||||
/// Get maximum multipart field size in megabytes (rounded up).
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust
|
||||
/// use kreuzberg::core::ServerConfig;
|
||||
///
|
||||
/// let mut config = ServerConfig::default();
|
||||
/// assert_eq!(config.max_multipart_field_mb(), 100);
|
||||
/// ```
|
||||
pub fn max_multipart_field_mb(&self) -> usize {
|
||||
self.max_multipart_field_bytes.div_ceil(1_048_576)
|
||||
}
|
||||
|
||||
/// Apply environment variable overrides to the configuration.
|
||||
///
|
||||
/// Reads the following environment variables and overrides config values if set:
|
||||
///
|
||||
/// - `KREUZBERG_HOST` - Server host address
|
||||
/// - `KREUZBERG_PORT` - Server port number (parsed as u16)
|
||||
/// - `KREUZBERG_CORS_ORIGINS` - Comma-separated list of allowed origins
|
||||
/// - `KREUZBERG_MAX_REQUEST_BODY_BYTES` - Max request body size in bytes
|
||||
/// - `KREUZBERG_MAX_MULTIPART_FIELD_BYTES` - Max multipart field size in bytes
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if:
|
||||
/// - `KREUZBERG_PORT` cannot be parsed as u16
|
||||
/// - `KREUZBERG_MAX_REQUEST_BODY_BYTES` cannot be parsed as usize
|
||||
/// - `KREUZBERG_MAX_MULTIPART_FIELD_BYTES` cannot be parsed as usize
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use kreuzberg::core::ServerConfig;
|
||||
///
|
||||
/// # fn example() -> kreuzberg::Result<()> {
|
||||
/// unsafe {
|
||||
/// std::env::set_var("KREUZBERG_HOST", "0.0.0.0");
|
||||
/// std::env::set_var("KREUZBERG_PORT", "3000");
|
||||
/// }
|
||||
///
|
||||
/// let mut config = ServerConfig::default();
|
||||
/// config.apply_env_overrides()?;
|
||||
///
|
||||
/// assert_eq!(config.host, "0.0.0.0");
|
||||
/// assert_eq!(config.port, 3000);
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn apply_env_overrides(&mut self) -> Result<()> {
|
||||
env::apply_env_overrides(
|
||||
&mut self.host,
|
||||
&mut self.port,
|
||||
&mut self.cors_origins,
|
||||
&mut self.max_request_body_bytes,
|
||||
&mut self.max_multipart_field_bytes,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Load server configuration from a file.
|
||||
///
|
||||
/// Automatically detects the file format based on extension:
|
||||
/// - `.toml` - TOML format
|
||||
/// - `.yaml` or `.yml` - YAML format
|
||||
/// - `.json` - JSON format
|
||||
///
|
||||
/// This function handles two config file formats:
|
||||
/// 1. Flat format: Server config at root level
|
||||
/// 2. Nested format: Server config under `[server]` section (combined with ExtractionConfig)
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the configuration file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if:
|
||||
/// - File doesn't exist or cannot be read
|
||||
/// - File extension is not recognized
|
||||
/// - File content is invalid for the detected format
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use kreuzberg::core::ServerConfig;
|
||||
///
|
||||
/// # fn example() -> kreuzberg::Result<()> {
|
||||
/// let config = ServerConfig::from_file("kreuzberg.toml")?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
|
||||
loader::from_file(path)
|
||||
}
|
||||
|
||||
/// Load server configuration from a TOML file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the TOML file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid TOML.
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```rust,no_run
|
||||
/// use kreuzberg::core::ServerConfig;
|
||||
///
|
||||
/// # fn example() -> kreuzberg::Result<()> {
|
||||
/// let config = ServerConfig::from_toml_file("kreuzberg.toml")?;
|
||||
/// # Ok(())
|
||||
/// # }
|
||||
/// ```
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn from_toml_file(path: impl AsRef<Path>) -> Result<Self> {
|
||||
loader::from_toml_file(path)
|
||||
}
|
||||
|
||||
/// Load server configuration from a YAML file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the YAML file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid YAML.
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn from_yaml_file(path: impl AsRef<Path>) -> Result<Self> {
|
||||
loader::from_yaml_file(path)
|
||||
}
|
||||
|
||||
/// Load server configuration from a JSON file.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `path` - Path to the JSON file
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns `KreuzbergError::Validation` if the file doesn't exist or is invalid JSON.
|
||||
#[cfg_attr(alef, alef(skip))]
|
||||
pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self> {
|
||||
loader::from_json_file(path)
|
||||
}
|
||||
}
|
||||
87
crates/kreuzberg/src/core/server_config/tests/basic_tests.rs
Normal file
87
crates/kreuzberg/src/core/server_config/tests/basic_tests.rs
Normal file
@@ -0,0 +1,87 @@
|
||||
//! Basic tests for ServerConfig functionality.
|
||||
|
||||
use crate::core::ServerConfig;
|
||||
|
||||
/// Every field of `ServerConfig::default()` carries its documented default.
#[test]
fn test_default_config() {
    // Destructuring exhaustively means a newly added field breaks this test
    // until its default is asserted as well.
    let ServerConfig {
        host,
        port,
        cors_origins,
        max_request_body_bytes,
        max_multipart_field_bytes,
    } = ServerConfig::default();

    assert_eq!(host, "127.0.0.1");
    assert_eq!(port, 8000);
    assert!(cors_origins.is_empty());
    assert_eq!(max_request_body_bytes, 104_857_600);
    assert_eq!(max_multipart_field_bytes, 104_857_600);
}
|
||||
|
||||
/// The default configuration listens on `127.0.0.1:8000`.
#[test]
fn test_listen_addr() {
    let addr = ServerConfig::default().listen_addr();
    assert_eq!(addr, "127.0.0.1:8000");
}
|
||||
|
||||
/// `listen_addr` reflects a custom host and port.
#[test]
fn test_listen_addr_custom() {
    let mut custom = ServerConfig::default();
    custom.host = "0.0.0.0".to_string();
    custom.port = 3000;

    assert_eq!(custom.listen_addr(), "0.0.0.0:3000");
}
|
||||
|
||||
/// `cors_allows_all` is true only while the origin list is empty.
#[test]
fn test_cors_allows_all() {
    let open = ServerConfig::default();
    assert!(open.cors_allows_all());

    let restricted = ServerConfig {
        cors_origins: vec!["https://example.com".to_string()],
        ..open
    };
    assert!(!restricted.cors_allows_all());
}
|
||||
|
||||
/// With an empty origin list every origin is accepted.
#[test]
fn test_is_origin_allowed_all() {
    let open = ServerConfig::default();

    for origin in ["https://example.com", "https://other.com"] {
        assert!(open.is_origin_allowed(origin));
    }
}
|
||||
|
||||
/// With a populated origin list only listed origins are accepted.
#[test]
fn test_is_origin_allowed_specific() {
    let restricted = ServerConfig {
        cors_origins: vec!["https://allowed.com".to_string()],
        ..Default::default()
    };

    assert!(restricted.is_origin_allowed("https://allowed.com"));
    assert!(!restricted.is_origin_allowed("https://denied.com"));
}
|
||||
|
||||
/// The default request body limit is reported as 100 MB.
#[test]
fn test_max_request_body_mb() {
    let megabytes = ServerConfig::default().max_request_body_mb();
    assert_eq!(megabytes, 100);
}
|
||||
|
||||
/// The default multipart field limit is reported as 100 MB.
#[test]
fn test_max_multipart_field_mb() {
    let megabytes = ServerConfig::default().max_multipart_field_mb();
    assert_eq!(megabytes, 100);
}
|
||||
|
||||
/// Byte limits that are not a whole number of megabytes round up.
#[test]
fn test_max_bytes_to_mb_rounding() {
    // (configured bytes, expected megabytes)
    let cases = [
        (1_048_576, 1), // exactly 1 MB
        (1_048_577, 2), // 1 MB + 1 byte rounds up
    ];

    for (bytes, expected_mb) in cases {
        let config = ServerConfig {
            max_request_body_bytes: bytes,
            ..Default::default()
        };
        assert_eq!(config.max_request_body_mb(), expected_mb);
    }
}
|
||||
|
||||
/// The default configuration serializes to JSON mentioning `host` and `port`.
#[test]
fn test_serde_default_serialization() {
    let serialized = serde_json::to_string(&ServerConfig::default()).unwrap();

    for key in ["host", "port"] {
        assert!(serialized.contains(key));
    }
}
|
||||
192
crates/kreuzberg/src/core/server_config/tests/env_tests.rs
Normal file
192
crates/kreuzberg/src/core/server_config/tests/env_tests.rs
Normal file
@@ -0,0 +1,192 @@
|
||||
//! Tests for environment variable overrides.
|
||||
|
||||
#![allow(unsafe_code)]
|
||||
|
||||
use crate::core::ServerConfig;
|
||||
|
||||
#[serial_test::serial]
|
||||
#[test]
|
||||
fn test_apply_env_host_override() {
|
||||
let original = std::env::var("KREUZBERG_HOST").ok();
|
||||
unsafe {
|
||||
std::env::set_var("KREUZBERG_HOST", "192.168.1.1");
|
||||
}
|
||||
|
||||
let mut config = ServerConfig::default();
|
||||
config.apply_env_overrides().unwrap();
|
||||
|
||||
assert_eq!(config.host, "192.168.1.1");
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
if let Some(orig) = original {
|
||||
std::env::set_var("KREUZBERG_HOST", orig);
|
||||
} else {
|
||||
std::env::remove_var("KREUZBERG_HOST");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[serial_test::serial]
|
||||
#[test]
|
||||
fn test_apply_env_port_override() {
|
||||
let original = std::env::var("KREUZBERG_PORT").ok();
|
||||
unsafe {
|
||||
std::env::set_var("KREUZBERG_PORT", "5000");
|
||||
}
|
||||
|
||||
let mut config = ServerConfig::default();
|
||||
config.apply_env_overrides().unwrap();
|
||||
|
||||
assert_eq!(config.port, 5000);
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
if let Some(orig) = original {
|
||||
std::env::set_var("KREUZBERG_PORT", orig);
|
||||
} else {
|
||||
std::env::remove_var("KREUZBERG_PORT");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[serial_test::serial]
|
||||
#[test]
|
||||
fn test_apply_env_port_invalid() {
|
||||
let original = std::env::var("KREUZBERG_PORT").ok();
|
||||
unsafe {
|
||||
std::env::set_var("KREUZBERG_PORT", "not_a_number");
|
||||
}
|
||||
|
||||
let mut config = ServerConfig::default();
|
||||
let result = config.apply_env_overrides();
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("KREUZBERG_PORT must be a valid u16")
|
||||
);
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
if let Some(orig) = original {
|
||||
std::env::set_var("KREUZBERG_PORT", orig);
|
||||
} else {
|
||||
std::env::remove_var("KREUZBERG_PORT");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[serial_test::serial]
|
||||
#[test]
|
||||
fn test_apply_env_cors_origins_override() {
|
||||
let original = std::env::var("KREUZBERG_CORS_ORIGINS").ok();
|
||||
unsafe {
|
||||
std::env::set_var("KREUZBERG_CORS_ORIGINS", "https://example.com, https://other.com");
|
||||
}
|
||||
|
||||
let mut config = ServerConfig::default();
|
||||
config.apply_env_overrides().unwrap();
|
||||
|
||||
assert_eq!(config.cors_origins.len(), 2);
|
||||
assert!(config.cors_origins.contains(&"https://example.com".to_string()));
|
||||
assert!(config.cors_origins.contains(&"https://other.com".to_string()));
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
if let Some(orig) = original {
|
||||
std::env::set_var("KREUZBERG_CORS_ORIGINS", orig);
|
||||
} else {
|
||||
std::env::remove_var("KREUZBERG_CORS_ORIGINS");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[serial_test::serial]
|
||||
#[test]
|
||||
fn test_apply_env_max_request_body_bytes_override() {
|
||||
let original = std::env::var("KREUZBERG_MAX_REQUEST_BODY_BYTES").ok();
|
||||
unsafe {
|
||||
std::env::set_var("KREUZBERG_MAX_REQUEST_BODY_BYTES", "52428800");
|
||||
}
|
||||
|
||||
let mut config = ServerConfig::default();
|
||||
config.apply_env_overrides().unwrap();
|
||||
|
||||
assert_eq!(config.max_request_body_bytes, 52_428_800);
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
if let Some(orig) = original {
|
||||
std::env::set_var("KREUZBERG_MAX_REQUEST_BODY_BYTES", orig);
|
||||
} else {
|
||||
std::env::remove_var("KREUZBERG_MAX_REQUEST_BODY_BYTES");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[serial_test::serial]
|
||||
#[test]
|
||||
fn test_apply_env_max_multipart_field_bytes_override() {
|
||||
let original = std::env::var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES").ok();
|
||||
unsafe {
|
||||
std::env::set_var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES", "78643200");
|
||||
}
|
||||
|
||||
let mut config = ServerConfig::default();
|
||||
config.apply_env_overrides().unwrap();
|
||||
|
||||
assert_eq!(config.max_multipart_field_bytes, 78_643_200);
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
if let Some(orig) = original {
|
||||
std::env::set_var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES", orig);
|
||||
} else {
|
||||
std::env::remove_var("KREUZBERG_MAX_MULTIPART_FIELD_BYTES");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[serial_test::serial]
|
||||
#[test]
|
||||
fn test_apply_env_multiple_overrides() {
|
||||
let host_orig = std::env::var("KREUZBERG_HOST").ok();
|
||||
let port_orig = std::env::var("KREUZBERG_PORT").ok();
|
||||
let cors_orig = std::env::var("KREUZBERG_CORS_ORIGINS").ok();
|
||||
|
||||
unsafe {
|
||||
std::env::set_var("KREUZBERG_HOST", "0.0.0.0");
|
||||
std::env::set_var("KREUZBERG_PORT", "4000");
|
||||
std::env::set_var("KREUZBERG_CORS_ORIGINS", "https://api.example.com");
|
||||
}
|
||||
|
||||
let mut config = ServerConfig::default();
|
||||
config.apply_env_overrides().unwrap();
|
||||
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 4000);
|
||||
assert_eq!(config.cors_origins.len(), 1);
|
||||
assert_eq!(config.cors_origins[0], "https://api.example.com");
|
||||
|
||||
// Cleanup
|
||||
unsafe {
|
||||
if let Some(orig) = host_orig {
|
||||
std::env::set_var("KREUZBERG_HOST", orig);
|
||||
} else {
|
||||
std::env::remove_var("KREUZBERG_HOST");
|
||||
}
|
||||
if let Some(orig) = port_orig {
|
||||
std::env::set_var("KREUZBERG_PORT", orig);
|
||||
} else {
|
||||
std::env::remove_var("KREUZBERG_PORT");
|
||||
}
|
||||
if let Some(orig) = cors_orig {
|
||||
std::env::set_var("KREUZBERG_CORS_ORIGINS", orig);
|
||||
} else {
|
||||
std::env::remove_var("KREUZBERG_CORS_ORIGINS");
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,321 @@
|
||||
//! Tests for file loading functionality.
|
||||
|
||||
use crate::core::ServerConfig;
|
||||
use std::fs;
|
||||
use tempfile::tempdir;
|
||||
|
||||
#[test]
|
||||
fn test_from_toml_file() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.toml");
|
||||
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
host = "0.0.0.0"
|
||||
port = 3000
|
||||
cors_origins = ["https://example.com", "https://other.com"]
|
||||
max_request_body_bytes = 50000000
|
||||
max_multipart_field_bytes = 75000000
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_toml_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 3000);
|
||||
assert_eq!(config.cors_origins.len(), 2);
|
||||
assert_eq!(config.max_request_body_bytes, 50_000_000);
|
||||
assert_eq!(config.max_multipart_field_bytes, 75_000_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_yaml_file() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.yaml");
|
||||
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
host: 0.0.0.0
|
||||
port: 3000
|
||||
cors_origins:
|
||||
- https://example.com
|
||||
- https://other.com
|
||||
max_request_body_bytes: 50000000
|
||||
max_multipart_field_bytes: 75000000
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_yaml_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 3000);
|
||||
assert_eq!(config.cors_origins.len(), 2);
|
||||
assert_eq!(config.max_request_body_bytes, 50_000_000);
|
||||
assert_eq!(config.max_multipart_field_bytes, 75_000_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_json_file() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.json");
|
||||
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"{
|
||||
"host": "0.0.0.0",
|
||||
"port": 3000,
|
||||
"cors_origins": ["https://example.com", "https://other.com"],
|
||||
"max_request_body_bytes": 50000000,
|
||||
"max_multipart_field_bytes": 75000000
|
||||
}
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_json_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 3000);
|
||||
assert_eq!(config.cors_origins.len(), 2);
|
||||
assert_eq!(config.max_request_body_bytes, 50_000_000);
|
||||
assert_eq!(config.max_multipart_field_bytes, 75_000_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_file_auto_detects_toml() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.toml");
|
||||
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
host = "0.0.0.0"
|
||||
port = 3000
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 3000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_file_auto_detects_yaml() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.yaml");
|
||||
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
host: 0.0.0.0
|
||||
port: 3000
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 3000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_file_auto_detects_json() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.json");
|
||||
|
||||
fs::write(&config_path, r#"{"host": "0.0.0.0", "port": 3000}"#).unwrap();
|
||||
|
||||
let config = ServerConfig::from_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 3000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_from_file_unsupported_extension() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.txt");
|
||||
|
||||
fs::write(&config_path, "host = 0.0.0.0").unwrap();
|
||||
|
||||
let result = ServerConfig::from_file(&config_path);
|
||||
assert!(result.is_err());
|
||||
assert!(
|
||||
result
|
||||
.unwrap_err()
|
||||
.to_string()
|
||||
.contains("Unsupported config file format")
|
||||
);
|
||||
}
|
||||
|
||||
/// Checks that `ServerConfig::from_file` rejects a path with no extension.
///
/// The file is named just `server` and is written to disk before loading. The
/// test requires an `Err` whose string form contains `"no extension found"`.
#[test]
|
||||
fn test_from_file_no_extension() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server");
|
||||
|
||||
fs::write(&config_path, "host = 0.0.0.0").unwrap();
|
||||
|
||||
let result = ServerConfig::from_file(&config_path);
|
||||
assert!(result.is_err());
|
||||
// Pin the error text so the failure reason (not just "some error") is verified.
assert!(result.unwrap_err().to_string().contains("no extension found"));
|
||||
}
|
||||
|
||||
/// Checks the CORS defaults when a TOML file omits `cors_origins`.
///
/// The fixture sets only `host` and `port` and is loaded with
/// `ServerConfig::from_toml_file`. The test passes when `cors_origins` is empty
/// and `cors_allows_all()` returns `true`.
#[test]
|
||||
fn test_cors_origins_empty_in_toml() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.toml");
|
||||
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
host = "127.0.0.1"
|
||||
port = 8000
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_toml_file(&config_path).unwrap();
|
||||
// An absent `cors_origins` key must yield an empty list, which counts as allow-all.
assert!(config.cors_origins.is_empty());
|
||||
assert!(config.cors_allows_all());
|
||||
}
|
||||
|
||||
/// Loads a TOML file that sets every field this test knows about and checks
/// both the raw fields and the derived helpers.
///
/// The fixture sets `host`, `port`, three `cors_origins`, and both size limits,
/// and is loaded with `ServerConfig::from_toml_file`. Assertions cover
/// `listen_addr()`, `cors_allows_all()`, `is_origin_allowed()`, and the two
/// `*_mb()` helpers.
#[test]
|
||||
fn test_full_configuration_toml() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.toml");
|
||||
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
host = "192.168.1.100"
|
||||
port = 9000
|
||||
cors_origins = ["https://app1.com", "https://app2.com", "https://app3.com"]
|
||||
max_request_body_bytes = 200000000
|
||||
max_multipart_field_bytes = 150000000
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_toml_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "192.168.1.100");
|
||||
assert_eq!(config.port, 9000);
|
||||
// `listen_addr()` is expected to be "<host>:<port>".
assert_eq!(config.listen_addr(), "192.168.1.100:9000");
|
||||
assert_eq!(config.cors_origins.len(), 3);
|
||||
// With an explicit origin list, allow-all is off and only listed origins pass.
assert!(!config.cors_allows_all());
|
||||
assert!(config.is_origin_allowed("https://app1.com"));
|
||||
assert!(!config.is_origin_allowed("https://app4.com"));
|
||||
assert_eq!(config.max_request_body_bytes, 200_000_000);
|
||||
assert_eq!(config.max_multipart_field_bytes, 150_000_000);
|
||||
// NOTE(review): 200_000_000 / 1_048_576 ≈ 190.73 and 150_000_000 / 1_048_576 ≈ 143.05,
// so 191 and 144 look like bytes-to-MiB rounded up — confirm against the helpers.
assert_eq!(config.max_request_body_mb(), 191);
|
||||
assert_eq!(config.max_multipart_field_mb(), 144);
|
||||
}
|
||||
|
||||
/// Checks that `ServerConfig::from_file` reads server settings from a `[server]`
/// table in a TOML file that also contains other tables.
///
/// The fixture `kreuzberg.toml` has `[server]`, `[ocr]` and `[extraction]`
/// tables. The test passes when host, port and the single CORS origin match the
/// values under `[server]`; the other tables must not cause the load to fail.
#[test]
|
||||
fn test_from_file_with_nested_server_section_toml() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("kreuzberg.toml");
|
||||
|
||||
// Config file with [server] section and other sections (like ExtractionConfig)
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
[server]
|
||||
host = "0.0.0.0"
|
||||
port = 3000
|
||||
cors_origins = ["https://example.com"]
|
||||
|
||||
[ocr]
|
||||
backend = "tesseract"
|
||||
language = "eng"
|
||||
|
||||
[extraction]
|
||||
enabled = true
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 3000);
|
||||
assert_eq!(config.cors_origins.len(), 1);
|
||||
assert_eq!(config.cors_origins[0], "https://example.com");
|
||||
}
|
||||
|
||||
/// Checks that `ServerConfig::from_file` reads server settings from a `server:`
/// mapping in a YAML file that also contains an `ocr:` mapping.
///
/// The fixture is `kreuzberg.yaml`. The test passes when host is `"0.0.0.0"`,
/// port is `4000`, and exactly one CORS origin was loaded.
///
/// NOTE(review): as the fixture appears here, the keys under `server:` and
/// `ocr:` carry no leading indentation. YAML needs indentation to nest them —
/// confirm the raw string in the real file is indented.
#[test]
|
||||
fn test_from_file_with_nested_server_section_yaml() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("kreuzberg.yaml");
|
||||
|
||||
// Config file with server: section and other sections
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
server:
|
||||
host: 0.0.0.0
|
||||
port: 4000
|
||||
cors_origins:
|
||||
- https://example.com
|
||||
|
||||
ocr:
|
||||
backend: tesseract
|
||||
language: eng
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 4000);
|
||||
assert_eq!(config.cors_origins.len(), 1);
|
||||
}
|
||||
|
||||
/// Checks that `ServerConfig::from_file` reads server settings from a `"server"`
/// object in a JSON file that also contains an `"ocr"` object.
///
/// The fixture is `kreuzberg.json`. The test passes when host is `"0.0.0.0"`,
/// port is `5000`, and exactly one CORS origin was loaded.
#[test]
|
||||
fn test_from_file_with_nested_server_section_json() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("kreuzberg.json");
|
||||
|
||||
// Config file with "server" key and other sections
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
{
|
||||
"server": {
|
||||
"host": "0.0.0.0",
|
||||
"port": 5000,
|
||||
"cors_origins": ["https://example.com"]
|
||||
},
|
||||
"ocr": {
|
||||
"backend": "tesseract",
|
||||
"language": "eng"
|
||||
}
|
||||
}
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "0.0.0.0");
|
||||
assert_eq!(config.port, 5000);
|
||||
assert_eq!(config.cors_origins.len(), 1);
|
||||
}
|
||||
|
||||
/// Regression check: a TOML file with top-level `host` and `port` (no `[server]`
/// table) still loads through `ServerConfig::from_file`.
///
/// The test passes when the loaded config reports host `"192.168.1.1"` and
/// port `6000`.
#[test]
|
||||
fn test_from_file_flat_format_still_works() {
|
||||
let dir = tempdir().unwrap();
|
||||
let config_path = dir.path().join("server.toml");
|
||||
|
||||
// Old flat format without [server] section
|
||||
fs::write(
|
||||
&config_path,
|
||||
r#"
|
||||
host = "192.168.1.1"
|
||||
port = 6000
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let config = ServerConfig::from_file(&config_path).unwrap();
|
||||
assert_eq!(config.host, "192.168.1.1");
|
||||
assert_eq!(config.port, 6000);
|
||||
}
|
||||
5
crates/kreuzberg/src/core/server_config/tests/mod.rs
Normal file
5
crates/kreuzberg/src/core/server_config/tests/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
//! Tests for server configuration module.
|
||||
|
||||
// Each `mod` declaration below has no inline body, so its contents live in a
// separate source file next to this one.
mod basic_tests;
|
||||
mod env_tests;
|
||||
// presumably holds the `ServerConfig::from_file` / `from_toml_file` tests — confirm.
mod file_loading_tests;
|
||||
4
crates/kreuzberg/src/core/server_config/validation.rs
Normal file
4
crates/kreuzberg/src/core/server_config/validation.rs
Normal file
@@ -0,0 +1,4 @@
|
||||
//! Validation and normalization for server configuration.
|
||||
//!
|
||||
//! This module provides functionality to validate and normalize server configuration
|
||||
//! values.
|
||||
Reference in New Issue
Block a user