diff --git a/Cargo.lock b/Cargo.lock index 4f945ac..e289687 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,6 +11,12 @@ dependencies = [ "memchr", ] +[[package]] +name = "base64" +version = "0.21.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d297deb1925b89f2ccc13d7635fa0714f12c87adce1c75356b39ca9b7178567" + [[package]] name = "bindgen" version = "0.69.5" @@ -115,6 +121,12 @@ dependencies = [ "either", ] +[[package]] +name = "itoa" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a5f13b858c8d314ee3e8f639011f7ccefe71f97f96e50151fb991f267928e2c" + [[package]] name = "lazy_static" version = "1.5.0" @@ -159,10 +171,12 @@ checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" name = "marisa-rs" version = "0.1.1" dependencies = [ + "base64", "bindgen", "cc", "libc", "pkg-config", + "serde_json", ] [[package]] @@ -275,6 +289,44 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "ryu" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" + +[[package]] +name = "serde" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.219" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.141" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30b9eff21ebe718216c6ec64e1d9ac57087aad11efc64e32002bce4a0d4c03d3" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + [[package]] name = "shlex" version = "1.3.0" diff --git a/Cargo.toml b/Cargo.toml index ea69eba..0376d9f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,6 +16,8 @@ crate-type = ["cdylib", "rlib"] [dependencies] libc = "0.2" +serde_json = "1.0" +base64 = "0.21" [build-dependencies] cc = "1.0" diff --git a/README.md b/README.md index f2065fa..4212297 100644 --- a/README.md +++ b/README.md @@ -90,6 +90,47 @@ mmapped_trie.mmap("my_trie.marisa")?; println!("Trie size: {} bytes", trie.io_size()); ``` +### RecordTrie Usage + +RecordTrie allows storing multiple structured records for each key: + +```rust +use marisa_rs::RecordTrie; + +// Create a builder +let mut builder = RecordTrie::builder(); + +// Add structured data (supports duplicate keys) +builder.insert_u32_pair("apple", (1, 100)); // price, quantity +builder.insert_u32_pair("apple", (2, 50)); // different record for same key +builder.insert_u32_pair("banana", (3, 200)); + +// Add vector data +builder.insert_u32_vec("features", vec![1, 2, 3, 4]); +builder.insert_u32_vec("features", vec![5, 6]); // different length + +// Add raw binary data +builder.insert("description", b"Fresh fruit".to_vec()); + +// Build the RecordTrie +let record_trie = builder.build().unwrap(); + +// Retrieve all records for a key +let apple_records = record_trie.get_u32_pairs("apple"); +println!("Apple records: {:?}", apple_records); // [(1, 100), (2, 50)] + +let feature_vecs = record_trie.get_u32_vecs("features"); +println!("Feature vectors: {:?}", feature_vecs); // [[1, 2, 3, 4], [5, 6]] + +// Prefix search works too +let fruit_keys = record_trie.keys_with_prefix("a"); +println!("Keys starting with 'a': {:?}", fruit_keys); // ["apple"] + +// Save and load +record_trie.save("trie.marisa", "records.json").unwrap(); +let loaded_trie = RecordTrie::load("trie.marisa", "records.json").unwrap(); +``` + ### Lookup Operations ```rust @@ -162,6 +203,24 @@ trie.build(&mut keyset)?; - `trie.io_size()` - Get the serialized size of the trie in bytes - `trie.clear()` - Clear the trie, removing all keys (returns `Result<(), &str>`) +### RecordTrie + +RecordTrie allows storing structured data associated with keys, similar to Python's marisa-trie RecordTrie: + +- `RecordTrie::builder()` - Create a new RecordTrie builder +- `builder.insert(key, data)` - Insert raw binary data for a key +- `builder.insert_u32_pair(key, (a, b))` - Insert a pair of u32 values +- `builder.insert_u32_vec(key, vec![...])` - Insert a vector of u32 values +- `builder.build()` - Build the RecordTrie +- `trie.get(key)` - Get all raw binary records for a key +- `trie.get_u32_pairs(key)` - Get all u32 pairs for a key +- `trie.get_u32_vecs(key)` - Get all u32 vectors for a key +- `trie.contains_key(key)` - Check if a key exists +- `trie.keys_with_prefix(prefix)` - Find all keys with given prefix +- `trie.prefixes_of(query)` - Find all keys that are prefixes of query +- `trie.save(trie_path, records_path)` - Save to files +- `RecordTrie::load(trie_path, records_path)` - Load from files + ## Japanese Text Example ```rust diff --git a/src/lib.rs b/src/lib.rs index 24bc247..d3cc1f9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -31,6 +31,7 @@ use std::slice; use std::path::Path; use std::ffi::CString; +use std::collections::HashMap; mod bindings { include!(concat!(env!("OUT_DIR"), "/bindings.rs")); @@ -387,6 +388,264 @@ unsafe impl Send for Keyset {} unsafe impl Send for Trie {} unsafe impl Send for Agent {} +/// A RecordTrie for storing structured data associated with keys. +/// +/// RecordTrie allows storing multiple structured records for each key, +/// similar to Python's marisa-trie RecordTrie functionality. +/// +/// # Example +/// +/// ```rust +/// use marisa_rs::RecordTrie; +/// +/// let mut record_trie = RecordTrie::new(); +/// record_trie.insert("apple", vec![1u32, 2u32]); +/// record_trie.insert("apple", vec![3u32, 4u32]); // Duplicate key +/// record_trie.insert("banana", vec![5u32, 6u32]); +/// +/// let trie = record_trie.build().unwrap(); +/// +/// // Lookup returns all records for a key +/// let apple_records = trie.get("apple"); +/// assert_eq!(apple_records.len(), 2); // Two records for "apple" +/// ``` +pub struct RecordTrie { + /// Internal trie for key management + trie: Trie, + /// Storage for record data, keyed by trie ID + records: HashMap>>, +} + +/// Builder for creating RecordTrie instances +pub struct RecordTrieBuilder { + /// Key-value pairs to be built into the trie + entries: HashMap>>, +} + +impl RecordTrieBuilder { + /// Creates a new RecordTrieBuilder + pub fn new() -> Self { + RecordTrieBuilder { + entries: HashMap::new(), + } + } + + /// Inserts a record for a given key + /// + /// Multiple records can be inserted for the same key. + /// + /// # Arguments + /// + /// * `key` - The string key + /// * `data` - Binary data as Vec + pub fn insert>(&mut self, key: K, data: Vec) { + let key_str = key.as_ref().to_string(); + self.entries.entry(key_str).or_insert_with(Vec::new).push(data); + } + + /// Inserts a record with automatic serialization for common types + /// + /// # Arguments + /// + /// * `key` - The string key + /// * `data` - Data that can be serialized to bytes + pub fn insert_u32_pair>(&mut self, key: K, data: (u32, u32)) { + let mut bytes = Vec::new(); + bytes.extend_from_slice(&data.0.to_le_bytes()); + bytes.extend_from_slice(&data.1.to_le_bytes()); + self.insert(key, bytes); + } + + /// Inserts multiple u32 values as a record + pub fn insert_u32_vec>(&mut self, key: K, data: Vec) { + let mut bytes = Vec::new(); + for value in data { + bytes.extend_from_slice(&value.to_le_bytes()); + } + self.insert(key, bytes); + } + + /// Builds the RecordTrie + /// + /// # Errors + /// + /// Returns an error if the trie cannot be built. + pub fn build(self) -> Result { + // Create keyset with all unique keys + let mut keyset = Keyset::new(); + let mut key_order = Vec::new(); + + for key in self.entries.keys() { + keyset.push(key); + key_order.push(key.clone()); + } + + // Build the internal trie + let mut trie = Trie::new(); + trie.build(&mut keyset)?; + + // Create records mapping using trie IDs + let mut records = HashMap::new(); + + for key in key_order { + if let Some(id) = trie.lookup(&key) { + if let Some(data_list) = self.entries.get(&key) { + records.insert(id, data_list.clone()); + } + } + } + + Ok(RecordTrie { trie, records }) + } +} + +impl RecordTrie { + /// Creates a new RecordTrieBuilder for constructing RecordTrie instances + pub fn builder() -> RecordTrieBuilder { + RecordTrieBuilder::new() + } + + /// Gets all records associated with a key + /// + /// # Arguments + /// + /// * `key` - The key to look up + /// + /// # Returns + /// + /// A vector of all records associated with the key, or an empty vector if the key is not found. + pub fn get>(&self, key: K) -> Vec<&Vec> { + if let Some(id) = self.trie.lookup(key.as_ref()) { + if let Some(record_list) = self.records.get(&id) { + return record_list.iter().collect(); + } + } + Vec::new() + } + + /// Gets records and deserializes them as u32 pairs + pub fn get_u32_pairs>(&self, key: K) -> Vec<(u32, u32)> { + self.get(key).into_iter() + .filter_map(|bytes| { + if bytes.len() >= 8 { + let first = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); + let second = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]); + Some((first, second)) + } else { + None + } + }) + .collect() + } + + /// Gets records and deserializes them as u32 vectors + pub fn get_u32_vecs>(&self, key: K) -> Vec> { + self.get(key).into_iter() + .map(|bytes| { + bytes.chunks_exact(4) + .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) + .collect() + }) + .collect() + } + + /// Checks if a key exists in the trie + pub fn contains_key>(&self, key: K) -> bool { + self.trie.lookup(key.as_ref()).is_some() + } + + /// Returns the number of unique keys in the trie + pub fn len(&self) -> usize { + self.trie.size() + } + + /// Returns true if the trie is empty + pub fn is_empty(&self) -> bool { + self.trie.is_empty() + } + + /// Gets all keys that have the given string as a prefix + pub fn keys_with_prefix>(&self, prefix: K) -> Vec { + let mut results = Vec::new(); + self.trie.predictive_search(prefix.as_ref(), |key, _id| { + results.push(key.to_string()); + }); + results + } + + /// Gets all keys that are prefixes of the given string + pub fn prefixes_of>(&self, query: K) -> Vec { + let mut results = Vec::new(); + self.trie.common_prefix_search(query.as_ref(), |key, _id| { + results.push(key.to_string()); + }); + results + } + + /// Saves the RecordTrie to files + /// + /// # Arguments + /// + /// * `trie_path` - Path for the trie structure file + /// * `records_path` - Path for the records data file + pub fn save>(&self, trie_path: P, records_path: P) -> Result<(), Box> { + // Save the internal trie + self.trie.save(&trie_path).map_err(|e| e.to_string())?; + + // Convert records to a JSON-serializable format + let mut json_records = std::collections::HashMap::new(); + for (id, record_list) in &self.records { + let encoded_records: Vec = record_list.iter() + .map(|bytes| base64::encode_engine(bytes, &base64::engine::general_purpose::STANDARD)) + .collect(); + json_records.insert(id.to_string(), encoded_records); + } + + // Serialize and save records as JSON + let json_data = serde_json::to_string(&json_records)?; + std::fs::write(records_path, json_data)?; + + Ok(()) + } + + /// Loads a RecordTrie from files + /// + /// # Arguments + /// + /// * `trie_path` - Path to the trie structure file + /// * `records_path` - Path to the records data file + pub fn load>(trie_path: P, records_path: P) -> Result> { + // Load the internal trie + let mut trie = Trie::new(); + trie.load(trie_path).map_err(|e| e.to_string())?; + + // Load and deserialize records from JSON + let json_data = std::fs::read_to_string(records_path)?; + let json_records: std::collections::HashMap> = serde_json::from_str(&json_data)?; + + // Convert back to internal format + let mut records = HashMap::new(); + for (id_str, encoded_records) in json_records { + let id: usize = id_str.parse()?; + let decoded_records: Result>, _> = encoded_records.iter() + .map(|encoded| base64::decode_engine(encoded, &base64::engine::general_purpose::STANDARD)) + .collect(); + records.insert(id, decoded_records?); + } + + Ok(RecordTrie { trie, records }) + } +} + +impl Default for RecordTrieBuilder { + fn default() -> Self { + Self::new() + } +} + +unsafe impl Send for RecordTrie {} +unsafe impl Send for RecordTrieBuilder {} + #[cfg(test)] mod tests { use super::*; @@ -540,4 +799,112 @@ mod tests { assert!(trie.is_empty()); assert!(trie.lookup("test").is_none()); } + + #[test] + fn test_record_trie_basic() { + let mut builder = RecordTrie::builder(); + + // Add records with duplicate keys + builder.insert_u32_pair("apple", (1, 2)); + builder.insert_u32_pair("apple", (3, 4)); // Duplicate key + builder.insert_u32_pair("banana", (5, 6)); + + let trie = builder.build().expect("Failed to build RecordTrie"); + + // Test basic functionality + assert_eq!(trie.len(), 2); // Two unique keys + assert!(!trie.is_empty()); + assert!(trie.contains_key("apple")); + assert!(trie.contains_key("banana")); + assert!(!trie.contains_key("orange")); + + // Test retrieving records + let apple_records = trie.get_u32_pairs("apple"); + assert_eq!(apple_records.len(), 2); + assert!(apple_records.contains(&(1, 2))); + assert!(apple_records.contains(&(3, 4))); + + let banana_records = trie.get_u32_pairs("banana"); + assert_eq!(banana_records.len(), 1); + assert_eq!(banana_records[0], (5, 6)); + + // Test non-existent key + let orange_records = trie.get_u32_pairs("orange"); + assert_eq!(orange_records.len(), 0); + } + + #[test] + fn test_record_trie_u32_vectors() { + let mut builder = RecordTrie::builder(); + + builder.insert_u32_vec("numbers", vec![1, 2, 3, 4]); + builder.insert_u32_vec("numbers", vec![5, 6]); // Different length + builder.insert_u32_vec("single", vec![42]); + + let trie = builder.build().expect("Failed to build RecordTrie"); + + let number_records = trie.get_u32_vecs("numbers"); + assert_eq!(number_records.len(), 2); + assert!(number_records.contains(&vec![1, 2, 3, 4])); + assert!(number_records.contains(&vec![5, 6])); + + let single_records = trie.get_u32_vecs("single"); + assert_eq!(single_records.len(), 1); + assert_eq!(single_records[0], vec![42]); + } + + #[test] + fn test_record_trie_raw_bytes() { + let mut builder = RecordTrie::builder(); + + builder.insert("text", b"hello".to_vec()); + builder.insert("text", b"world".to_vec()); + builder.insert("binary", vec![0xFF, 0x00, 0xAA, 0x55]); + + let trie = builder.build().expect("Failed to build RecordTrie"); + + let text_records = trie.get("text"); + assert_eq!(text_records.len(), 2); + assert!(text_records.iter().any(|&bytes| bytes == b"hello")); + assert!(text_records.iter().any(|&bytes| bytes == b"world")); + + let binary_records = trie.get("binary"); + assert_eq!(binary_records.len(), 1); + assert_eq!(binary_records[0], &vec![0xFF, 0x00, 0xAA, 0x55]); + } + + #[test] + fn test_record_trie_prefix_search() { + let mut builder = RecordTrie::builder(); + + builder.insert_u32_pair("app", (1, 1)); + builder.insert_u32_pair("apple", (2, 2)); + builder.insert_u32_pair("application", (3, 3)); + builder.insert_u32_pair("banana", (4, 4)); + + let trie = builder.build().expect("Failed to build RecordTrie"); + + // Test keys with prefix + let app_keys = trie.keys_with_prefix("app"); + assert_eq!(app_keys.len(), 3); + assert!(app_keys.contains(&"app".to_string())); + assert!(app_keys.contains(&"apple".to_string())); + assert!(app_keys.contains(&"application".to_string())); + + // Test prefixes of a key + let prefixes = trie.prefixes_of("application"); + assert!(prefixes.contains(&"app".to_string())); + assert!(prefixes.contains(&"application".to_string())); + } + + #[test] + fn test_record_trie_empty() { + let builder = RecordTrie::builder(); + let trie = builder.build().expect("Failed to build empty RecordTrie"); + + assert_eq!(trie.len(), 0); + assert!(trie.is_empty()); + assert!(!trie.contains_key("anything")); + assert_eq!(trie.get("anything").len(), 0); + } } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 7da8725..f6217ad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,4 +1,4 @@ -use marisa_rs::{Keyset, Trie}; +use marisa_rs::{Keyset, Trie, RecordTrie}; fn main() { println!("Marisa Trie Demo"); @@ -49,6 +49,7 @@ fn main() { }); test_save_load_functionality(&trie); + test_record_trie_functionality(); } fn test_save_load_functionality(original_trie: &Trie) { @@ -168,3 +169,139 @@ fn test_save_load_functionality(original_trie: &Trie) { Err(e) => println!("\n⚠ Warning: Could not remove test file '{}': {}", test_file, e), } } + +fn test_record_trie_functionality() { + println!("\n=== RecordTrie Functionality Test ==="); + + // Create a RecordTrie with duplicate keys and different data types + let mut builder = RecordTrie::builder(); + + println!("Building RecordTrie with structured data..."); + + // Add some records with duplicate keys + builder.insert_u32_pair("apple", (1, 100)); + builder.insert_u32_pair("apple", (2, 200)); // Duplicate key + builder.insert_u32_pair("banana", (3, 300)); + builder.insert_u32_pair("cherry", (4, 400)); + + // Add vector data + builder.insert_u32_vec("numbers", vec![1, 2, 3, 4, 5]); + builder.insert_u32_vec("numbers", vec![10, 20]); // Duplicate key, different data + + // Add raw binary data + builder.insert("text", b"Hello, World!".to_vec()); + builder.insert("text", b"Rust is awesome!".to_vec()); + + match builder.build() { + Ok(record_trie) => { + println!("✓ RecordTrie built successfully! Keys: {}", record_trie.len()); + + // Test basic lookup + println!("\n--- Testing Basic Lookup ---"); + test_record_lookup(&record_trie); + + // Test prefix searches + println!("\n--- Testing Prefix Searches ---"); + test_record_prefix_search(&record_trie); + + // Test save/load functionality + println!("\n--- Testing RecordTrie Save/Load ---"); + test_record_save_load(&record_trie); + + } + Err(e) => { + println!("✗ Failed to build RecordTrie: {}", e); + } + } +} + +fn test_record_lookup(record_trie: &RecordTrie) { + // Test u32 pairs + let apple_records = record_trie.get_u32_pairs("apple"); + println!("'apple' records: {} found", apple_records.len()); + for (i, (first, second)) in apple_records.iter().enumerate() { + println!(" Record {}: ({}, {})", i + 1, first, second); + } + + let banana_records = record_trie.get_u32_pairs("banana"); + println!("'banana' records: {} found", banana_records.len()); + for (i, (first, second)) in banana_records.iter().enumerate() { + println!(" Record {}: ({}, {})", i + 1, first, second); + } + + // Test u32 vectors + let number_records = record_trie.get_u32_vecs("numbers"); + println!("'numbers' records: {} found", number_records.len()); + for (i, vec) in number_records.iter().enumerate() { + println!(" Record {}: {:?}", i + 1, vec); + } + + // Test raw bytes + let text_records = record_trie.get("text"); + println!("'text' records: {} found", text_records.len()); + for (i, bytes) in text_records.iter().enumerate() { + if let Ok(text) = std::str::from_utf8(bytes) { + println!(" Record {}: \"{}\"", i + 1, text); + } else { + println!(" Record {}: {:?} (binary)", i + 1, bytes); + } + } + + // Test non-existent key + let missing_records = record_trie.get_u32_pairs("missing"); + println!("'missing' records: {} found", missing_records.len()); +} + +fn test_record_prefix_search(record_trie: &RecordTrie) { + // Test prefix search + let keys_with_prefix = record_trie.keys_with_prefix("a"); + println!("Keys starting with 'a': {:?}", keys_with_prefix); + + // Test finding prefixes + let prefixes = record_trie.prefixes_of("application"); + println!("Prefixes of 'application': {:?}", prefixes); +} + +fn test_record_save_load(record_trie: &RecordTrie) { + let trie_file = "test_record_trie.marisa"; + let records_file = "test_record_data.json"; + + // Clean up from any previous test + let _ = std::fs::remove_file(trie_file); + let _ = std::fs::remove_file(records_file); + + // Test save + match record_trie.save(trie_file, records_file) { + Ok(()) => println!("✓ RecordTrie saved successfully"), + Err(e) => { + println!("✗ Failed to save RecordTrie: {}", e); + return; + } + } + + // Test load + match RecordTrie::load(trie_file, records_file) { + Ok(loaded_trie) => { + println!("✓ RecordTrie loaded successfully! Keys: {}", loaded_trie.len()); + + // Verify the loaded trie works + let apple_records = loaded_trie.get_u32_pairs("apple"); + println!(" Verification: 'apple' has {} records after load", apple_records.len()); + + if apple_records.len() == 2 { + println!("✓ RecordTrie save/load verification passed!"); + } else { + println!("⚠ RecordTrie save/load verification failed"); + } + } + Err(e) => { + println!("✗ Failed to load RecordTrie: {}", e); + } + } + + // Clean up + match (std::fs::remove_file(trie_file), std::fs::remove_file(records_file)) { + (Ok(()), Ok(())) => println!("✓ Cleanup: Removed test files"), + _ => println!("⚠ Warning: Could not remove some test files"), + } +}