This commit is contained in:
456
crates/kreuzberg/tests/api_embed.rs
Normal file
456
crates/kreuzberg/tests/api_embed.rs
Normal file
@@ -0,0 +1,456 @@
|
||||
//! Integration tests for the /embed API endpoint.
|
||||
|
||||
#![cfg(all(feature = "api", feature = "embeddings"))]
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
http::{Request, StatusCode},
|
||||
};
|
||||
use serde_json::json;
|
||||
use tower::ServiceExt;
|
||||
|
||||
use kreuzberg::{
|
||||
ExtractionConfig,
|
||||
api::{EmbedResponse, create_router},
|
||||
};
|
||||
|
||||
/// Test embed endpoint with valid texts.
|
||||
#[tokio::test]
|
||||
async fn test_embed_valid_texts() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
let request_body = json!({
|
||||
"texts": ["Hello world", "Second text"]
|
||||
});
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_body).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to convert to bytes");
|
||||
let embed_response: EmbedResponse = serde_json::from_slice(&body).expect("Failed to deserialize");
|
||||
|
||||
assert_eq!(embed_response.count, 2);
|
||||
assert_eq!(embed_response.embeddings.len(), 2);
|
||||
assert!(embed_response.dimensions > 0);
|
||||
assert!(!embed_response.model.is_empty());
|
||||
|
||||
// Verify embeddings have correct dimensions
|
||||
for embedding in &embed_response.embeddings {
|
||||
assert_eq!(embedding.len(), embed_response.dimensions);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test embed endpoint with empty texts array returns 400.
|
||||
#[tokio::test]
|
||||
async fn test_embed_empty_texts() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
let request_body = json!({
|
||||
"texts": []
|
||||
});
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_body).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
/// Test embed endpoint with custom embedding configuration.
|
||||
#[tokio::test]
|
||||
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
|
||||
async fn test_embed_with_custom_config() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
let request_body = json!({
|
||||
"texts": ["Test embedding with custom config"],
|
||||
"config": {
|
||||
"model": {
|
||||
"type": "preset",
|
||||
"name": "fast"
|
||||
},
|
||||
"batch_size": 32
|
||||
}
|
||||
});
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_body).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to convert to bytes");
|
||||
let embed_response: EmbedResponse = serde_json::from_slice(&body).expect("Failed to deserialize");
|
||||
|
||||
assert_eq!(embed_response.count, 1);
|
||||
assert_eq!(embed_response.embeddings.len(), 1);
|
||||
assert_eq!(embed_response.model, "fast");
|
||||
}
|
||||
|
||||
/// Test embed endpoint with single text.
|
||||
#[tokio::test]
|
||||
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
|
||||
async fn test_embed_single_text() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
let request_body = json!({
|
||||
"texts": ["Single text for embedding"]
|
||||
});
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_body).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to convert to bytes");
|
||||
let embed_response: EmbedResponse = serde_json::from_slice(&body).expect("Failed to deserialize");
|
||||
|
||||
assert_eq!(embed_response.count, 1);
|
||||
assert_eq!(embed_response.embeddings.len(), 1);
|
||||
}
|
||||
|
||||
/// Test embed endpoint with multiple texts (batch).
|
||||
#[tokio::test]
|
||||
async fn test_embed_batch() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
let texts: Vec<String> = (0..10).map(|i| format!("Test text number {}", i)).collect();
|
||||
|
||||
let request_body = json!({
|
||||
"texts": texts
|
||||
});
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_body).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to convert to bytes");
|
||||
let embed_response: EmbedResponse = serde_json::from_slice(&body).expect("Failed to deserialize");
|
||||
|
||||
assert_eq!(embed_response.count, 10);
|
||||
assert_eq!(embed_response.embeddings.len(), 10);
|
||||
|
||||
// Verify all embeddings have the same dimensions
|
||||
let first_dim = embed_response.embeddings[0].len();
|
||||
for embedding in &embed_response.embeddings {
|
||||
assert_eq!(embedding.len(), first_dim);
|
||||
}
|
||||
}
|
||||
|
||||
/// Test embed endpoint with long text.
|
||||
#[tokio::test]
|
||||
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
|
||||
async fn test_embed_long_text() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
let long_text = "Lorem ipsum dolor sit amet, consectetur adipiscing elit. ".repeat(100);
|
||||
|
||||
let request_body = json!({
|
||||
"texts": [long_text]
|
||||
});
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_body).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to convert to bytes");
|
||||
let embed_response: EmbedResponse = serde_json::from_slice(&body).expect("Failed to deserialize");
|
||||
|
||||
assert_eq!(embed_response.count, 1);
|
||||
assert_eq!(embed_response.embeddings.len(), 1);
|
||||
}
|
||||
|
||||
/// Test embed endpoint with malformed JSON returns 400.
|
||||
#[tokio::test]
|
||||
async fn test_embed_malformed_json() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from("{invalid json}"))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
}
|
||||
|
||||
/// Test embed endpoint rejects JSON array at root level.
|
||||
#[tokio::test]
|
||||
async fn test_embed_rejects_json_array() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
// Send a JSON array instead of object
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(r#"[["text1"], {"texts": ["text2"]}]"#))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
// Should reject with 400 or 422, NOT 200
|
||||
assert!(
|
||||
response.status() == StatusCode::BAD_REQUEST || response.status() == StatusCode::UNPROCESSABLE_ENTITY,
|
||||
"Expected 400 or 422, got {}",
|
||||
response.status()
|
||||
);
|
||||
}
|
||||
|
||||
/// Test embed endpoint rejects simple JSON array with strings.
|
||||
#[tokio::test]
|
||||
async fn test_embed_rejects_simple_json_array() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
// Send a simple string array instead of object with texts field
|
||||
let response = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(r#"["text1", "text2", "text3"]"#))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response.status(), StatusCode::BAD_REQUEST);
|
||||
|
||||
// Check that error response contains helpful message
|
||||
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to read response body");
|
||||
let error_response: serde_json::Value = serde_json::from_slice(&body).expect("Failed to parse error response");
|
||||
|
||||
assert!(
|
||||
error_response["message"]
|
||||
.as_str()
|
||||
.map(|msg| msg.contains("array") || msg.contains("object"))
|
||||
.unwrap_or(false)
|
||||
);
|
||||
}
|
||||
|
||||
/// Test embed endpoint preserves embedding vector values across calls.
|
||||
#[tokio::test]
|
||||
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
|
||||
async fn test_embed_deterministic() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
let request_body = json!({
|
||||
"texts": ["Deterministic test"]
|
||||
});
|
||||
|
||||
// First call
|
||||
let response1 = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_body).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response1.status(), StatusCode::OK);
|
||||
|
||||
let body1 = axum::body::to_bytes(response1.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to convert to bytes");
|
||||
let embed_response1: EmbedResponse = serde_json::from_slice(&body1).expect("Failed to deserialize");
|
||||
|
||||
// Second call with same text
|
||||
let response2 = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_body).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response2.status(), StatusCode::OK);
|
||||
|
||||
let body2 = axum::body::to_bytes(response2.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Failed to convert to bytes");
|
||||
let embed_response2: EmbedResponse = serde_json::from_slice(&body2).expect("Failed to deserialize");
|
||||
|
||||
// Compare embeddings - they should be identical
|
||||
assert_eq!(embed_response1.embeddings.len(), embed_response2.embeddings.len());
|
||||
assert_eq!(embed_response1.embeddings[0], embed_response2.embeddings[0]);
|
||||
}
|
||||
|
||||
/// Test embed endpoint with different embedding presets.
|
||||
#[tokio::test]
|
||||
#[cfg_attr(target_arch = "aarch64", ignore = "ONNX Runtime model loading unstable on ARM")]
|
||||
async fn test_embed_different_presets() {
|
||||
let app = create_router(ExtractionConfig::default());
|
||||
|
||||
// Test with "fast" preset
|
||||
let request_fast = json!({
|
||||
"texts": ["Test text"],
|
||||
"config": {
|
||||
"model": {
|
||||
"type": "preset",
|
||||
"name": "fast"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let response_fast = app
|
||||
.clone()
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_fast).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response_fast.status(), StatusCode::OK);
|
||||
|
||||
let body_fast = axum::body::to_bytes(response_fast.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
let embed_fast: EmbedResponse = serde_json::from_slice(&body_fast).expect("Failed to deserialize");
|
||||
|
||||
// Test with "balanced" preset
|
||||
let request_balanced = json!({
|
||||
"texts": ["Test text"],
|
||||
"config": {
|
||||
"model": {
|
||||
"type": "preset",
|
||||
"name": "balanced"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let response_balanced = app
|
||||
.oneshot(
|
||||
Request::builder()
|
||||
.method("POST")
|
||||
.uri("/embed")
|
||||
.header("content-type", "application/json")
|
||||
.body(Body::from(
|
||||
serde_json::to_string(&request_balanced).expect("Operation failed"),
|
||||
))
|
||||
.expect("Operation failed"),
|
||||
)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
|
||||
assert_eq!(response_balanced.status(), StatusCode::OK);
|
||||
|
||||
let body_balanced = axum::body::to_bytes(response_balanced.into_body(), usize::MAX)
|
||||
.await
|
||||
.expect("Operation failed");
|
||||
let embed_balanced: EmbedResponse = serde_json::from_slice(&body_balanced).expect("Failed to deserialize");
|
||||
|
||||
// Different presets should have different dimensions
|
||||
assert_ne!(embed_fast.dimensions, embed_balanced.dimensions);
|
||||
assert_eq!(embed_fast.model, "fast");
|
||||
assert_eq!(embed_balanced.model, "balanced");
|
||||
}
|
||||
Reference in New Issue
Block a user