diff --git a/Cargo.lock b/Cargo.lock index 6d05c54..4f945ac 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -157,7 +157,7 @@ checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94" [[package]] name = "marisa-rs" -version = "0.1.0" +version = "0.1.1" dependencies = [ "bindgen", "cc", diff --git a/Cargo.toml b/Cargo.toml index 833fc91..ea69eba 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "marisa-rs" -version = "0.1.0" +version = "0.1.1" edition = "2021" description = "Safe Rust wrapper for the marisa-trie C++ library" license = "LGPL-2.1-or-later" diff --git a/README.md b/README.md index b9ace87..f2065fa 100644 --- a/README.md +++ b/README.md @@ -57,11 +57,39 @@ keyset.push("car"); keyset.push("card"); keyset.push("care"); -// Build the trie +// Build the trie let mut trie = Trie::new(); trie.build(&mut keyset)?; ``` +### Saving and Loading Tries + +```rust +use marisa_rs::{Keyset, Trie}; + +// Build a trie +let mut keyset = Keyset::new(); +keyset.push("hello"); +keyset.push("world"); + +let mut trie = Trie::new(); +trie.build(&mut keyset)?; + +// Save the trie to a file +trie.save("my_trie.marisa")?; + +// Load the trie from a file +let mut loaded_trie = Trie::new(); +loaded_trie.load("my_trie.marisa")?; + +// Or use memory mapping for better performance with large tries +let mut mmapped_trie = Trie::new(); +mmapped_trie.mmap("my_trie.marisa")?; + +// Check the serialized size before saving +println!("Trie size: {} bytes", trie.io_size()); +``` + ### Lookup Operations ```rust @@ -128,6 +156,11 @@ trie.build(&mut keyset)?; - `trie.predictive_search(query, callback)` - Find all keys that start with query - `trie.size()` - Get the number of keys in the trie - `trie.is_empty()` - Check if the trie is empty +- `trie.save(path)` - Save the trie to a file (returns `Result<(), &str>`) +- `trie.load(path)` - Load a trie from a file (returns `Result<(), &str>`) +- `trie.mmap(path)` - Memory-map a trie file for efficient read-only access (returns `Result<(), &str>`) +- `trie.io_size()` - Get the serialized size of the trie in bytes +- `trie.clear()` - Clear the trie, removing all keys (returns `Result<(), &str>`) ## Japanese Text Example diff --git a/src/lib.rs b/src/lib.rs index 4b1a656..24bc247 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,8 @@ //! ``` use std::slice; +use std::path::Path; +use std::ffi::CString; mod bindings { include!(concat!(env!("OUT_DIR"), "/bindings.rs")); @@ -198,13 +200,124 @@ impl Trie { /// Returns the number of keys stored in the trie. pub fn size(&self) -> usize { - unsafe { marisa_trie_size(self.inner) } + unsafe { + let mut size: usize = 0; + if marisa_trie_size(self.inner, &mut size) == 1 { + size + } else { + 0 // Return 0 if there's an error + } + } } /// Returns true if the trie is empty. pub fn is_empty(&self) -> bool { self.size() == 0 } + + /// Saves the trie to a file. + /// + /// # Arguments + /// + /// * `path` - The file path to save the trie to + /// + /// # Errors + /// + /// Returns an error if the trie cannot be saved to the file. + pub fn save>(&self, path: P) -> Result<(), &'static str> { + let path_str = path.as_ref().to_string_lossy(); + let c_path = CString::new(path_str.as_ref()) + .map_err(|_| "Invalid path: contains null bytes")?; + + unsafe { + if marisa_trie_save(self.inner, c_path.as_ptr()) == 1 { + Ok(()) + } else { + Err("Failed to save trie") + } + } + } + + /// Loads a trie from a file. + /// + /// # Arguments + /// + /// * `path` - The file path to load the trie from + /// + /// # Errors + /// + /// Returns an error if the trie cannot be loaded from the file. + pub fn load>(&mut self, path: P) -> Result<(), &'static str> { + let path_str = path.as_ref().to_string_lossy(); + let c_path = CString::new(path_str.as_ref()) + .map_err(|_| "Invalid path: contains null bytes")?; + + unsafe { + if marisa_trie_load(self.inner, c_path.as_ptr()) == 1 { + Ok(()) + } else { + Err("Failed to load trie") + } + } + } + + /// Memory-maps a trie file for read-only access. + /// + /// This is more memory-efficient than loading the entire trie into memory. + /// + /// # Arguments + /// + /// * `path` - The file path to memory-map the trie from + /// + /// # Errors + /// + /// Returns an error if the trie cannot be memory-mapped from the file. + pub fn mmap>(&mut self, path: P) -> Result<(), &'static str> { + let path_str = path.as_ref().to_string_lossy(); + let c_path = CString::new(path_str.as_ref()) + .map_err(|_| "Invalid path: contains null bytes")?; + + unsafe { + if marisa_trie_mmap(self.inner, c_path.as_ptr()) == 1 { + Ok(()) + } else { + Err("Failed to memory-map trie") + } + } + } + + /// Returns the size of the trie when serialized. + /// + /// This can be useful for determining the storage requirements + /// before saving the trie to a file. + pub fn io_size(&self) -> usize { + unsafe { + let mut size: usize = 0; + if marisa_trie_io_size(self.inner, &mut size) == 1 { + size + } else { + 0 // Return 0 if there's an error + } + } + } + + /// Clears the trie, removing all keys. + /// + /// After calling this method, the trie will be empty and + /// must be rebuilt from a keyset before it can be used again. + /// + /// # Errors + /// + /// Returns an error if the trie cannot be cleared. + pub fn clear(&mut self) -> Result<(), &'static str> { + unsafe { + if marisa_trie_clear(self.inner) == 1 { + Ok(()) + } else { + Err("Failed to clear trie") + } + } + } } impl Drop for Trie { @@ -330,4 +443,101 @@ mod tests { assert!(trie.is_empty()); assert_eq!(trie.size(), 0); } + + #[test] + fn test_save_and_load() { + use std::fs; + + let test_file = "test_trie.marisa"; + + // Clean up from any previous test + let _ = fs::remove_file(test_file); + + // Build a trie + let mut keyset = Keyset::new(); + keyset.push("apple"); + keyset.push("application"); + keyset.push("apply"); + + let mut original_trie = Trie::new(); + original_trie.build(&mut keyset).expect("Failed to build trie"); + + // Test io_size + assert!(original_trie.io_size() > 0); + + // Save the trie + original_trie.save(test_file).expect("Failed to save trie"); + + // Load the trie + let mut loaded_trie = Trie::new(); + loaded_trie.load(test_file).expect("Failed to load trie"); + + // Verify the loaded trie works the same + assert_eq!(loaded_trie.size(), original_trie.size()); + assert!(loaded_trie.lookup("apple").is_some()); + assert!(loaded_trie.lookup("application").is_some()); + assert!(loaded_trie.lookup("apply").is_some()); + assert!(loaded_trie.lookup("banana").is_none()); + + // Test reverse lookup + if let Some(id) = loaded_trie.lookup("apple") { + assert_eq!(loaded_trie.reverse_lookup(id).unwrap(), "apple"); + } + + // Clean up + let _ = fs::remove_file(test_file); + } + + #[test] + fn test_mmap() { + use std::fs; + + let test_file = "test_mmap_trie.marisa"; + + // Clean up from any previous test + let _ = fs::remove_file(test_file); + + // Build and save a trie + let mut keyset = Keyset::new(); + keyset.push("memory"); + keyset.push("mapped"); + keyset.push("trie"); + + let mut original_trie = Trie::new(); + original_trie.build(&mut keyset).expect("Failed to build trie"); + original_trie.save(test_file).expect("Failed to save trie"); + + // Memory-map the trie + let mut mmapped_trie = Trie::new(); + mmapped_trie.mmap(test_file).expect("Failed to mmap trie"); + + // Verify the memory-mapped trie works + assert_eq!(mmapped_trie.size(), 3); + assert!(mmapped_trie.lookup("memory").is_some()); + assert!(mmapped_trie.lookup("mapped").is_some()); + assert!(mmapped_trie.lookup("trie").is_some()); + assert!(mmapped_trie.lookup("nonexistent").is_none()); + + // Clean up + let _ = fs::remove_file(test_file); + } + + #[test] + fn test_clear() { + let mut keyset = Keyset::new(); + keyset.push("test"); + + let mut trie = Trie::new(); + trie.build(&mut keyset).expect("Failed to build trie"); + + assert_eq!(trie.size(), 1); + assert!(!trie.is_empty()); + + // Clear the trie + trie.clear().expect("Failed to clear trie"); + + assert_eq!(trie.size(), 0); + assert!(trie.is_empty()); + assert!(trie.lookup("test").is_none()); + } } \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 4bffc74..7da8725 100644 --- a/src/main.rs +++ b/src/main.rs @@ -47,4 +47,124 @@ fn main() { trie.predictive_search("app", |key, id| { println!(" Found: '{}' (ID: {})", key, id); }); + + test_save_load_functionality(&trie); +} + +fn test_save_load_functionality(original_trie: &Trie) { + println!("\n=== Save/Load Functionality Test ==="); + + let test_file = "demo_trie.marisa"; + + // Test io_size + let trie_size = original_trie.io_size(); + println!("Original trie serialized size: {} bytes", trie_size); + + // Test save + println!("\nSaving trie to '{}'...", test_file); + match original_trie.save(test_file) { + Ok(()) => println!("✓ Trie saved successfully!"), + Err(e) => { + println!("✗ Failed to save trie: {}", e); + return; + } + } + + // Test load + println!("\nLoading trie from '{}'...", test_file); + let mut loaded_trie = Trie::new(); + match loaded_trie.load(test_file) { + Ok(()) => println!("✓ Trie loaded successfully! Size: {}", loaded_trie.size()), + Err(e) => { + println!("✗ Failed to load trie: {}", e); + return; + } + } + + // Verify loaded trie functionality + println!("\n--- Verifying Loaded Trie ---"); + let test_words = vec!["apple", "application", "banana"]; + for word in &test_words { + let original_result = original_trie.lookup(word); + let loaded_result = loaded_trie.lookup(word); + match (original_result, loaded_result) { + (Some(orig_id), Some(load_id)) if orig_id == load_id => { + println!("✓ '{}' lookup matches (ID: {})", word, orig_id); + } + (None, None) => { + println!("✓ '{}' not found in both tries", word); + } + _ => { + println!("✗ '{}' lookup mismatch: orig={:?}, loaded={:?}", + word, original_result, loaded_result); + } + } + } + + // Test memory mapping + println!("\n--- Testing Memory Mapping ---"); + let mut mmapped_trie = Trie::new(); + match mmapped_trie.mmap(test_file) { + Ok(()) => { + println!("✓ Trie memory-mapped successfully! Size: {}", mmapped_trie.size()); + + // Test a lookup on memory-mapped trie + if let Some(id) = mmapped_trie.lookup("apple") { + match mmapped_trie.reverse_lookup(id) { + Ok(word) => println!("✓ Memory-mapped reverse lookup: ID {} -> '{}'", id, word), + Err(e) => println!("✗ Memory-mapped reverse lookup failed: {}", e), + } + } + } + Err(e) => println!("✗ Failed to memory-map trie: {}", e), + } + + // Test clear functionality + println!("\n--- Testing Clear Functionality ---"); + let mut test_trie = Trie::new(); + let mut keyset = Keyset::new(); + keyset.push("test"); + keyset.push("clear"); + + if test_trie.build(&mut keyset).is_ok() { + println!("Created test trie with {} keys", test_trie.size()); + + // Test lookup before clear + let before_lookup = test_trie.lookup("test").is_some(); + println!("Before clear - 'test' found: {}", before_lookup); + + // Clear the trie + match test_trie.clear() { + Ok(()) => println!("✓ Trie clear operation successful"), + Err(e) => { + println!("✗ Failed to clear trie: {}", e); + return; + } + } + + // Test after clear + println!("Testing trie state after clear..."); + + let after_size = test_trie.size(); + let is_empty = test_trie.is_empty(); + let after_lookup = test_trie.lookup("test").is_some(); + + println!(" Size after clear: {}", after_size); + println!(" Is empty after clear: {}", is_empty); + println!(" 'test' found after clear: {}", after_lookup); + + if after_size == 0 && is_empty && !after_lookup { + println!("✓ Clear functionality works correctly!"); + } else { + println!("⚠ Clear may not have worked as expected"); + } + } else { + println!("✗ Failed to build test trie for clear functionality"); + } + + // Clean up + match std::fs::remove_file(test_file) { + Ok(()) => println!("\n✓ Cleanup: Removed test file '{}'", test_file), + Err(e) => println!("\n⚠ Warning: Could not remove test file '{}': {}", test_file, e), + } } diff --git a/wrapper.cpp b/wrapper.cpp index 7fea9d2..8143b4a 100644 --- a/wrapper.cpp +++ b/wrapper.cpp @@ -82,9 +82,64 @@ int marisa_trie_predictive_search(const MarisaTrie* trie, MarisaAgent* agent) { } } -size_t marisa_trie_size(const MarisaTrie* trie) { - const marisa::Trie* tr = reinterpret_cast(trie); - return tr->size(); +int marisa_trie_size(const MarisaTrie* trie, size_t* size) { + try { + const marisa::Trie* tr = reinterpret_cast(trie); + *size = tr->size(); + return 1; + } catch (...) { + return 0; + } +} + +int marisa_trie_save(const MarisaTrie* trie, const char* filename) { + try { + const marisa::Trie* tr = reinterpret_cast(trie); + tr->save(filename); + return 1; + } catch (...) { + return 0; + } +} + +int marisa_trie_load(MarisaTrie* trie, const char* filename) { + try { + marisa::Trie* tr = reinterpret_cast(trie); + tr->load(filename); + return 1; + } catch (...) { + return 0; + } +} + +int marisa_trie_mmap(MarisaTrie* trie, const char* filename) { + try { + marisa::Trie* tr = reinterpret_cast(trie); + tr->mmap(filename); + return 1; + } catch (...) { + return 0; + } +} + +int marisa_trie_io_size(const MarisaTrie* trie, size_t* size) { + try { + const marisa::Trie* tr = reinterpret_cast(trie); + *size = tr->io_size(); + return 1; + } catch (...) { + return 0; + } +} + +int marisa_trie_clear(MarisaTrie* trie) { + try { + marisa::Trie* tr = reinterpret_cast(trie); + tr->clear(); + return 1; + } catch (...) { + return 0; + } } MarisaAgent* marisa_agent_new() { diff --git a/wrapper.h b/wrapper.h index 9289af5..09e044a 100644 --- a/wrapper.h +++ b/wrapper.h @@ -25,7 +25,12 @@ int marisa_trie_lookup(const MarisaTrie* trie, MarisaAgent* agent); int marisa_trie_reverse_lookup(const MarisaTrie* trie, MarisaAgent* agent); int marisa_trie_common_prefix_search(const MarisaTrie* trie, MarisaAgent* agent); int marisa_trie_predictive_search(const MarisaTrie* trie, MarisaAgent* agent); -size_t marisa_trie_size(const MarisaTrie* trie); +int marisa_trie_size(const MarisaTrie* trie, size_t* size); +int marisa_trie_save(const MarisaTrie* trie, const char* filename); +int marisa_trie_load(MarisaTrie* trie, const char* filename); +int marisa_trie_mmap(MarisaTrie* trie, const char* filename); +int marisa_trie_io_size(const MarisaTrie* trie, size_t* size); +int marisa_trie_clear(MarisaTrie* trie); // Agent functions MarisaAgent* marisa_agent_new();