//! # marisa-rs
//!
//! Safe Rust wrapper for the marisa-trie C++ library.
//!
//! marisa-trie is a static and space-efficient trie data structure library.
//! This crate provides safe Rust bindings to the C++ library.
//!
//! ## Example
//!
//! ```rust
//! use marisa_rs::{Keyset, Trie};
//!
//! let mut keyset = Keyset::new();
//! keyset.push("apple");
//! keyset.push("application");
//! keyset.push("apply");
//!
//! let mut trie = Trie::new();
//! trie.build(&mut keyset).unwrap();
//!
//! // Lookup
//! assert!(trie.lookup("apple").is_some());
//! assert!(trie.lookup("orange").is_none());
//!
//! // Common prefix search
//! trie.common_prefix_search("application", |key, id| {
//!     println!("Found: {} (ID: {})", key, id);
//! });
//! ```

use std::slice;
use std::path::Path;
use std::ffi::CString;
use std::collections::HashMap;

mod bindings {
    include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
}

use bindings::*;

/// A keyset for building a trie.
///
/// Keyset is used to store a collection of keys before building a trie.
/// Keys can be added with different weights.
pub struct Keyset {
    /// Handle returned by `marisa_keyset_new`; released in `Drop`.
    inner: *mut MarisaKeyset,
}

impl Keyset {
    /// Creates a new empty keyset.
    ///
    /// # Panics
    ///
    /// Panics if the native constructor returns a null pointer, so that no
    /// other method ever hands a null handle to the C layer.
    pub fn new() -> Self {
        // SAFETY: the constructor takes no arguments; its result is checked
        // for null before it is stored.
        let inner = unsafe { marisa_keyset_new() };
        assert!(!inner.is_null(), "marisa_keyset_new returned a null pointer");
        Keyset { inner }
    }

    /// Adds a key with the specified weight to the keyset.
    ///
    /// The key is passed to the native layer as a pointer/length pair, so it
    /// does not need to be NUL-terminated.
    pub fn push_back(&mut self, key: &str, weight: f32) {
        let key_bytes = key.as_bytes();
        // SAFETY: `self.inner` is non-null (checked in `new`) and the
        // pointer/length pair describes `key`, which stays borrowed for the
        // duration of the call.
        // NOTE(review): presumably the native side copies the bytes before
        // returning - confirm against the C wrapper.
        unsafe {
            marisa_keyset_push_back(
                self.inner,
                // `.cast()` targets whatever pointer type the binding
                // declares; a hard-coded `*const i8` does not compile on
                // targets where `c_char` is unsigned.
                key_bytes.as_ptr().cast(),
                key_bytes.len(),
                weight,
            );
        }
    }

    /// Adds a key with default weight (1.0) to the keyset.
    pub fn push(&mut self, key: &str) {
        self.push_back(key, 1.0);
    }

    /// Returns the number of keys in the keyset.
    pub fn size(&self) -> usize {
        // SAFETY: `self.inner` is a live handle owned by `self`.
        unsafe { marisa_keyset_size(self.inner) }
    }

    /// Returns true if the keyset is empty.
    pub fn is_empty(&self) -> bool {
        self.size() == 0
    }
}

impl Default for Keyset {
    /// Equivalent to [`Keyset::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for Keyset {
    fn drop(&mut self) {
        // SAFETY: `self.inner` came from `marisa_keyset_new` and is released
        // exactly once, here.
        unsafe {
            marisa_keyset_delete(self.inner);
        }
    }
}

/// A trie data structure for efficient string lookups.
/// /// The trie must be built from a keyset before it can be used for lookups. pub struct Trie { inner: *mut MarisaTrie, } impl Trie { /// Creates a new empty trie. pub fn new() -> Self { unsafe { let inner = marisa_trie_new(); Trie { inner } } } /// Builds the trie from the given keyset. /// /// # Errors /// /// Returns an error if the trie cannot be built from the keyset. pub fn build(&mut self, keyset: &mut Keyset) -> Result<(), &'static str> { unsafe { if marisa_trie_build(self.inner, keyset.inner) == 1 { Ok(()) } else { Err("Failed to build trie") } } } /// Looks up a key in the trie and returns its ID if found. /// /// # Returns /// /// - `Some(id)` if the key is found in the trie /// - `None` if the key is not found pub fn lookup(&self, key: &str) -> Option { let mut agent = Agent::new(); agent.set_query(key); unsafe { if marisa_trie_lookup(self.inner, agent.inner) == 1 { Some(agent.key_id()) } else { None } } } /// Performs reverse lookup to get the key corresponding to the given ID. /// /// # Errors /// /// Returns an error if the ID is not valid. pub fn reverse_lookup(&self, id: usize) -> Result { let mut agent = Agent::new(); agent.set_query_by_id(id); unsafe { if marisa_trie_reverse_lookup(self.inner, agent.inner) == 1 { Ok(agent.key_string()) } else { Err("Failed to reverse lookup") } } } /// Searches for all keys that are prefixes of the given query. /// /// The callback function is called for each matching key with the key and its ID. pub fn common_prefix_search(&self, query: &str, mut callback: F) where F: FnMut(&str, usize), { let mut agent = Agent::new(); agent.set_query(query); unsafe { while marisa_trie_common_prefix_search(self.inner, agent.inner) == 1 { let key = agent.key_string(); let id = agent.key_id(); callback(&key, id); } } } /// Searches for all keys that have the given query as a prefix. /// /// The callback function is called for each matching key with the key and its ID. 
pub fn predictive_search(&self, query: &str, mut callback: F) where F: FnMut(&str, usize), { let mut agent = Agent::new(); agent.set_query(query); unsafe { while marisa_trie_predictive_search(self.inner, agent.inner) == 1 { let key = agent.key_string(); let id = agent.key_id(); callback(&key, id); } } } /// Returns the number of keys stored in the trie. pub fn size(&self) -> usize { unsafe { let mut size: usize = 0; if marisa_trie_size(self.inner, &mut size) == 1 { size } else { 0 // Return 0 if there's an error } } } /// Returns true if the trie is empty. pub fn is_empty(&self) -> bool { self.size() == 0 } /// Saves the trie to a file. /// /// # Arguments /// /// * `path` - The file path to save the trie to /// /// # Errors /// /// Returns an error if the trie cannot be saved to the file. pub fn save>(&self, path: P) -> Result<(), &'static str> { let path_str = path.as_ref().to_string_lossy(); let c_path = CString::new(path_str.as_ref()) .map_err(|_| "Invalid path: contains null bytes")?; unsafe { if marisa_trie_save(self.inner, c_path.as_ptr()) == 1 { Ok(()) } else { Err("Failed to save trie") } } } /// Loads a trie from a file. /// /// # Arguments /// /// * `path` - The file path to load the trie from /// /// # Errors /// /// Returns an error if the trie cannot be loaded from the file. pub fn load>(&mut self, path: P) -> Result<(), &'static str> { let path_str = path.as_ref().to_string_lossy(); let c_path = CString::new(path_str.as_ref()) .map_err(|_| "Invalid path: contains null bytes")?; unsafe { if marisa_trie_load(self.inner, c_path.as_ptr()) == 1 { Ok(()) } else { Err("Failed to load trie") } } } /// Memory-maps a trie file for read-only access. /// /// This is more memory-efficient than loading the entire trie into memory. /// /// # Arguments /// /// * `path` - The file path to memory-map the trie from /// /// # Errors /// /// Returns an error if the trie cannot be memory-mapped from the file. 
pub fn mmap>(&mut self, path: P) -> Result<(), &'static str> { let path_str = path.as_ref().to_string_lossy(); let c_path = CString::new(path_str.as_ref()) .map_err(|_| "Invalid path: contains null bytes")?; unsafe { if marisa_trie_mmap(self.inner, c_path.as_ptr()) == 1 { Ok(()) } else { Err("Failed to memory-map trie") } } } /// Returns the size of the trie when serialized. /// /// This can be useful for determining the storage requirements /// before saving the trie to a file. pub fn io_size(&self) -> usize { unsafe { let mut size: usize = 0; if marisa_trie_io_size(self.inner, &mut size) == 1 { size } else { 0 // Return 0 if there's an error } } } /// Clears the trie, removing all keys. /// /// After calling this method, the trie will be empty and /// must be rebuilt from a keyset before it can be used again. /// /// # Errors /// /// Returns an error if the trie cannot be cleared. pub fn clear(&mut self) -> Result<(), &'static str> { unsafe { if marisa_trie_clear(self.inner) == 1 { Ok(()) } else { Err("Failed to clear trie") } } } } impl Drop for Trie { fn drop(&mut self) { unsafe { marisa_trie_delete(self.inner); } } } /// An agent for performing trie operations. /// /// Agent is used internally for trie operations and should not be used directly /// in most cases. 
pub struct Agent {
    /// Handle returned by `marisa_agent_new`; released in `Drop`.
    inner: *mut MarisaAgent,
}

impl Agent {
    /// Creates a new agent with no query set.
    ///
    /// # Panics
    ///
    /// Panics if the native constructor returns a null pointer.
    pub fn new() -> Self {
        // SAFETY: the constructor takes no arguments; its result is checked
        // for null before it is stored.
        let inner = unsafe { marisa_agent_new() };
        assert!(!inner.is_null(), "marisa_agent_new returned a null pointer");
        Agent { inner }
    }

    /// Sets the query string used by the next trie operation.
    ///
    /// The query is handed over as a pointer/length pair.
    // NOTE(review): whether the native agent copies the bytes or keeps the
    // pointer is not visible here - callers in this file keep `query`
    // borrowed until the search finishes; confirm before using it elsewhere.
    pub fn set_query(&mut self, query: &str) {
        let query_bytes = query.as_bytes();
        // SAFETY: `self.inner` is non-null (checked in `new`) and the
        // pointer/length pair describes `query`.
        unsafe {
            marisa_agent_set_query(
                self.inner,
                // `.cast()` targets whatever pointer type the binding
                // declares; a hard-coded `*const i8` does not compile on
                // targets where `c_char` is unsigned.
                query_bytes.as_ptr().cast(),
                query_bytes.len(),
            );
        }
    }

    /// Sets the query to a key ID, as used by reverse lookup.
    pub fn set_query_by_id(&mut self, id: usize) {
        // SAFETY: `self.inner` is a live handle owned by `self`.
        unsafe {
            marisa_agent_set_query_by_id(self.inner, id);
        }
    }

    /// Returns the agent's current key as an owned string.
    ///
    /// Invalid UTF-8 sequences are replaced with U+FFFD. An empty string is
    /// returned when the agent holds no key (null pointer or zero length).
    pub fn key_string(&self) -> String {
        // SAFETY: `self.inner` is a live handle owned by `self`.
        let (ptr, len) = unsafe {
            (
                marisa_agent_key_ptr(self.inner),
                marisa_agent_key_length(self.inner),
            )
        };

        // `slice::from_raw_parts` requires a non-null pointer even for a
        // zero-length slice, so the "no key" case must be handled first.
        if ptr.is_null() || len == 0 {
            return String::new();
        }

        // SAFETY: `ptr` is non-null and, per the agent API, addresses `len`
        // readable bytes; they are copied out before `self` is used again.
        let bytes = unsafe { slice::from_raw_parts(ptr.cast::<u8>(), len) };
        String::from_utf8_lossy(bytes).into_owned()
    }

    /// Returns the ID of the agent's current key.
    pub fn key_id(&self) -> usize {
        // SAFETY: `self.inner` is a live handle owned by `self`.
        unsafe { marisa_agent_key_id(self.inner) }
    }
}

impl Default for Agent {
    /// Equivalent to [`Agent::new`].
    fn default() -> Self {
        Self::new()
    }
}

impl Drop for Agent {
    fn drop(&mut self) {
        // SAFETY: `self.inner` came from `marisa_agent_new` and is released
        // exactly once, here.
        unsafe {
            marisa_agent_delete(self.inner);
        }
    }
}

// SAFETY: each wrapper exclusively owns its native handle, so moving the
// wrapper moves sole ownership to the other thread.
// NOTE(review): this assumes the native objects have no thread affinity
// (no thread-local state) - confirm against the C++ library.
unsafe impl Send for Keyset {}
unsafe impl Send for Trie {}
unsafe impl Send for Agent {}

/// A RecordTrie for storing structured data associated with keys.
///
/// RecordTrie allows storing multiple structured records for each key,
/// similar to Python's marisa-trie RecordTrie functionality.
/// /// # Example /// /// ```rust /// use marisa_rs::RecordTrie; /// /// let mut record_trie = RecordTrie::new(); /// record_trie.insert("apple", vec![1u32, 2u32]); /// record_trie.insert("apple", vec![3u32, 4u32]); // Duplicate key /// record_trie.insert("banana", vec![5u32, 6u32]); /// /// let trie = record_trie.build().unwrap(); /// /// // Lookup returns all records for a key /// let apple_records = trie.get("apple"); /// assert_eq!(apple_records.len(), 2); // Two records for "apple" /// ``` pub struct RecordTrie { /// Internal trie for key management trie: Trie, /// Storage for record data, keyed by trie ID records: HashMap>>, } /// Builder for creating RecordTrie instances pub struct RecordTrieBuilder { /// Key-value pairs to be built into the trie entries: HashMap>>, } impl RecordTrieBuilder { /// Creates a new RecordTrieBuilder pub fn new() -> Self { RecordTrieBuilder { entries: HashMap::new(), } } /// Inserts a record for a given key /// /// Multiple records can be inserted for the same key. /// /// # Arguments /// /// * `key` - The string key /// * `data` - Binary data as Vec pub fn insert>(&mut self, key: K, data: Vec) { let key_str = key.as_ref().to_string(); self.entries.entry(key_str).or_insert_with(Vec::new).push(data); } /// Inserts a record with automatic serialization for common types /// /// # Arguments /// /// * `key` - The string key /// * `data` - Data that can be serialized to bytes pub fn insert_u32_pair>(&mut self, key: K, data: (u32, u32)) { let mut bytes = Vec::new(); bytes.extend_from_slice(&data.0.to_le_bytes()); bytes.extend_from_slice(&data.1.to_le_bytes()); self.insert(key, bytes); } /// Inserts multiple u32 values as a record pub fn insert_u32_vec>(&mut self, key: K, data: Vec) { let mut bytes = Vec::new(); for value in data { bytes.extend_from_slice(&value.to_le_bytes()); } self.insert(key, bytes); } /// Builds the RecordTrie /// /// # Errors /// /// Returns an error if the trie cannot be built. 
pub fn build(self) -> Result { // Create keyset with all unique keys let mut keyset = Keyset::new(); let mut key_order = Vec::new(); for key in self.entries.keys() { keyset.push(key); key_order.push(key.clone()); } // Build the internal trie let mut trie = Trie::new(); trie.build(&mut keyset)?; // Create records mapping using trie IDs let mut records = HashMap::new(); for key in key_order { if let Some(id) = trie.lookup(&key) { if let Some(data_list) = self.entries.get(&key) { records.insert(id, data_list.clone()); } } } Ok(RecordTrie { trie, records }) } } impl RecordTrie { /// Creates a new RecordTrieBuilder for constructing RecordTrie instances pub fn builder() -> RecordTrieBuilder { RecordTrieBuilder::new() } /// Gets all records associated with a key /// /// # Arguments /// /// * `key` - The key to look up /// /// # Returns /// /// A vector of all records associated with the key, or an empty vector if the key is not found. pub fn get>(&self, key: K) -> Vec<&Vec> { if let Some(id) = self.trie.lookup(key.as_ref()) { if let Some(record_list) = self.records.get(&id) { return record_list.iter().collect(); } } Vec::new() } /// Gets records and deserializes them as u32 pairs pub fn get_u32_pairs>(&self, key: K) -> Vec<(u32, u32)> { self.get(key).into_iter() .filter_map(|bytes| { if bytes.len() >= 8 { let first = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); let second = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]); Some((first, second)) } else { None } }) .collect() } /// Gets records and deserializes them as u32 vectors pub fn get_u32_vecs>(&self, key: K) -> Vec> { self.get(key).into_iter() .map(|bytes| { bytes.chunks_exact(4) .map(|chunk| u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])) .collect() }) .collect() } /// Checks if a key exists in the trie pub fn contains_key>(&self, key: K) -> bool { self.trie.lookup(key.as_ref()).is_some() } /// Returns the number of unique keys in the trie pub fn len(&self) -> usize { 
self.trie.size() } /// Returns true if the trie is empty pub fn is_empty(&self) -> bool { self.trie.is_empty() } /// Gets all keys that have the given string as a prefix pub fn keys_with_prefix>(&self, prefix: K) -> Vec { let mut results = Vec::new(); self.trie.predictive_search(prefix.as_ref(), |key, _id| { results.push(key.to_string()); }); results } /// Gets all keys that are prefixes of the given string pub fn prefixes_of>(&self, query: K) -> Vec { let mut results = Vec::new(); self.trie.common_prefix_search(query.as_ref(), |key, _id| { results.push(key.to_string()); }); results } /// Saves the RecordTrie to files /// /// # Arguments /// /// * `trie_path` - Path for the trie structure file /// * `records_path` - Path for the records data file pub fn save>(&self, trie_path: P, records_path: P) -> Result<(), Box> { // Save the internal trie self.trie.save(&trie_path).map_err(|e| e.to_string())?; // Convert records to a JSON-serializable format let mut json_records = std::collections::HashMap::new(); for (id, record_list) in &self.records { let encoded_records: Vec = record_list.iter() .map(|bytes| base64::encode_engine(bytes, &base64::engine::general_purpose::STANDARD)) .collect(); json_records.insert(id.to_string(), encoded_records); } // Serialize and save records as JSON let json_data = serde_json::to_string(&json_records)?; std::fs::write(records_path, json_data)?; Ok(()) } /// Loads a RecordTrie from files /// /// # Arguments /// /// * `trie_path` - Path to the trie structure file /// * `records_path` - Path to the records data file pub fn load>(trie_path: P, records_path: P) -> Result> { // Load the internal trie let mut trie = Trie::new(); trie.load(trie_path).map_err(|e| e.to_string())?; // Load and deserialize records from JSON let json_data = std::fs::read_to_string(records_path)?; let json_records: std::collections::HashMap> = serde_json::from_str(&json_data)?; // Convert back to internal format let mut records = HashMap::new(); for (id_str, 
encoded_records) in json_records { let id: usize = id_str.parse()?; let decoded_records: Result>, _> = encoded_records.iter() .map(|encoded| base64::decode_engine(encoded, &base64::engine::general_purpose::STANDARD)) .collect(); records.insert(id, decoded_records?); } Ok(RecordTrie { trie, records }) } } impl Default for RecordTrieBuilder { fn default() -> Self { Self::new() } } unsafe impl Send for RecordTrie {} unsafe impl Send for RecordTrieBuilder {} #[cfg(test)] mod tests { use super::*; #[test] fn test_basic_operations() { let mut keyset = Keyset::new(); keyset.push("apple"); keyset.push("application"); keyset.push("apply"); keyset.push("apricot"); assert_eq!(keyset.size(), 4); let mut trie = Trie::new(); trie.build(&mut keyset).expect("Failed to build trie"); assert_eq!(trie.size(), 4); // Test lookup assert!(trie.lookup("apple").is_some()); assert!(trie.lookup("banana").is_none()); // Test reverse lookup if let Some(id) = trie.lookup("apple") { assert_eq!(trie.reverse_lookup(id).unwrap(), "apple"); } // Test common prefix search let mut results = Vec::new(); trie.common_prefix_search("application", |key, id| { results.push((key.to_string(), id)); }); assert!(results.len() > 0); // Test predictive search let mut results = Vec::new(); trie.predictive_search("app", |key, id| { results.push((key.to_string(), id)); }); assert!(results.len() > 0); } #[test] fn test_empty_keyset() { let keyset = Keyset::new(); assert!(keyset.is_empty()); assert_eq!(keyset.size(), 0); } #[test] fn test_empty_trie() { let trie = Trie::new(); assert!(trie.is_empty()); assert_eq!(trie.size(), 0); } #[test] fn test_save_and_load() { use std::fs; let test_file = "test_trie.marisa"; // Clean up from any previous test let _ = fs::remove_file(test_file); // Build a trie let mut keyset = Keyset::new(); keyset.push("apple"); keyset.push("application"); keyset.push("apply"); let mut original_trie = Trie::new(); original_trie.build(&mut keyset).expect("Failed to build trie"); // Test io_size 
assert!(original_trie.io_size() > 0); // Save the trie original_trie.save(test_file).expect("Failed to save trie"); // Load the trie let mut loaded_trie = Trie::new(); loaded_trie.load(test_file).expect("Failed to load trie"); // Verify the loaded trie works the same assert_eq!(loaded_trie.size(), original_trie.size()); assert!(loaded_trie.lookup("apple").is_some()); assert!(loaded_trie.lookup("application").is_some()); assert!(loaded_trie.lookup("apply").is_some()); assert!(loaded_trie.lookup("banana").is_none()); // Test reverse lookup if let Some(id) = loaded_trie.lookup("apple") { assert_eq!(loaded_trie.reverse_lookup(id).unwrap(), "apple"); } // Clean up let _ = fs::remove_file(test_file); } #[test] fn test_mmap() { use std::fs; let test_file = "test_mmap_trie.marisa"; // Clean up from any previous test let _ = fs::remove_file(test_file); // Build and save a trie let mut keyset = Keyset::new(); keyset.push("memory"); keyset.push("mapped"); keyset.push("trie"); let mut original_trie = Trie::new(); original_trie.build(&mut keyset).expect("Failed to build trie"); original_trie.save(test_file).expect("Failed to save trie"); // Memory-map the trie let mut mmapped_trie = Trie::new(); mmapped_trie.mmap(test_file).expect("Failed to mmap trie"); // Verify the memory-mapped trie works assert_eq!(mmapped_trie.size(), 3); assert!(mmapped_trie.lookup("memory").is_some()); assert!(mmapped_trie.lookup("mapped").is_some()); assert!(mmapped_trie.lookup("trie").is_some()); assert!(mmapped_trie.lookup("nonexistent").is_none()); // Clean up let _ = fs::remove_file(test_file); } #[test] fn test_clear() { let mut keyset = Keyset::new(); keyset.push("test"); let mut trie = Trie::new(); trie.build(&mut keyset).expect("Failed to build trie"); assert_eq!(trie.size(), 1); assert!(!trie.is_empty()); // Clear the trie trie.clear().expect("Failed to clear trie"); assert_eq!(trie.size(), 0); assert!(trie.is_empty()); assert!(trie.lookup("test").is_none()); } #[test] fn 
test_record_trie_basic() { let mut builder = RecordTrie::builder(); // Add records with duplicate keys builder.insert_u32_pair("apple", (1, 2)); builder.insert_u32_pair("apple", (3, 4)); // Duplicate key builder.insert_u32_pair("banana", (5, 6)); let trie = builder.build().expect("Failed to build RecordTrie"); // Test basic functionality assert_eq!(trie.len(), 2); // Two unique keys assert!(!trie.is_empty()); assert!(trie.contains_key("apple")); assert!(trie.contains_key("banana")); assert!(!trie.contains_key("orange")); // Test retrieving records let apple_records = trie.get_u32_pairs("apple"); assert_eq!(apple_records.len(), 2); assert!(apple_records.contains(&(1, 2))); assert!(apple_records.contains(&(3, 4))); let banana_records = trie.get_u32_pairs("banana"); assert_eq!(banana_records.len(), 1); assert_eq!(banana_records[0], (5, 6)); // Test non-existent key let orange_records = trie.get_u32_pairs("orange"); assert_eq!(orange_records.len(), 0); } #[test] fn test_record_trie_u32_vectors() { let mut builder = RecordTrie::builder(); builder.insert_u32_vec("numbers", vec![1, 2, 3, 4]); builder.insert_u32_vec("numbers", vec![5, 6]); // Different length builder.insert_u32_vec("single", vec![42]); let trie = builder.build().expect("Failed to build RecordTrie"); let number_records = trie.get_u32_vecs("numbers"); assert_eq!(number_records.len(), 2); assert!(number_records.contains(&vec![1, 2, 3, 4])); assert!(number_records.contains(&vec![5, 6])); let single_records = trie.get_u32_vecs("single"); assert_eq!(single_records.len(), 1); assert_eq!(single_records[0], vec![42]); } #[test] fn test_record_trie_raw_bytes() { let mut builder = RecordTrie::builder(); builder.insert("text", b"hello".to_vec()); builder.insert("text", b"world".to_vec()); builder.insert("binary", vec![0xFF, 0x00, 0xAA, 0x55]); let trie = builder.build().expect("Failed to build RecordTrie"); let text_records = trie.get("text"); assert_eq!(text_records.len(), 2); assert!(text_records.iter().any(|&bytes| 
bytes == b"hello")); assert!(text_records.iter().any(|&bytes| bytes == b"world")); let binary_records = trie.get("binary"); assert_eq!(binary_records.len(), 1); assert_eq!(binary_records[0], &vec![0xFF, 0x00, 0xAA, 0x55]); } #[test] fn test_record_trie_prefix_search() { let mut builder = RecordTrie::builder(); builder.insert_u32_pair("app", (1, 1)); builder.insert_u32_pair("apple", (2, 2)); builder.insert_u32_pair("application", (3, 3)); builder.insert_u32_pair("banana", (4, 4)); let trie = builder.build().expect("Failed to build RecordTrie"); // Test keys with prefix let app_keys = trie.keys_with_prefix("app"); assert_eq!(app_keys.len(), 3); assert!(app_keys.contains(&"app".to_string())); assert!(app_keys.contains(&"apple".to_string())); assert!(app_keys.contains(&"application".to_string())); // Test prefixes of a key let prefixes = trie.prefixes_of("application"); assert!(prefixes.contains(&"app".to_string())); assert!(prefixes.contains(&"application".to_string())); } #[test] fn test_record_trie_empty() { let builder = RecordTrie::builder(); let trie = builder.build().expect("Failed to build empty RecordTrie"); assert_eq!(trie.len(), 0); assert!(trie.is_empty()); assert!(!trie.contains_key("anything")); assert_eq!(trie.get("anything").len(), 0); } }