From 46a1077d998f54d0ad782d83cf3204f80a985b49 Mon Sep 17 00:00:00 2001
From: Soma Nakamura
Date: Thu, 24 Jul 2025 04:25:06 +0900
Subject: [PATCH] update save functions
---
Cargo.lock | 2 +-
Cargo.toml | 2 +-
README.md | 35 ++++++++-
src/lib.rs | 212 +++++++++++++++++++++++++++++++++++++++++++++++++++-
src/main.rs | 120 +++++++++++++++++++++++++++++
wrapper.cpp | 61 ++++++++++++++-
wrapper.h | 7 +-
7 files changed, 431 insertions(+), 8 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 6d05c54..4f945ac 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -157,7 +157,7 @@ checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]]
name = "marisa-rs"
-version = "0.1.0"
+version = "0.1.1"
dependencies = [
"bindgen",
"cc",
diff --git a/Cargo.toml b/Cargo.toml
index 833fc91..ea69eba 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
[package]
name = "marisa-rs"
-version = "0.1.0"
+version = "0.1.1"
edition = "2021"
description = "Safe Rust wrapper for the marisa-trie C++ library"
license = "LGPL-2.1-or-later"
diff --git a/README.md b/README.md
index b9ace87..f2065fa 100644
--- a/README.md
+++ b/README.md
@@ -57,11 +57,39 @@ keyset.push("car");
keyset.push("card");
keyset.push("care");
-// Build the trie
+// Build the trie
let mut trie = Trie::new();
trie.build(&mut keyset)?;
```
+### Saving and Loading Tries
+
+```rust
+use marisa_rs::{Keyset, Trie};
+
+// Build a trie
+let mut keyset = Keyset::new();
+keyset.push("hello");
+keyset.push("world");
+
+let mut trie = Trie::new();
+trie.build(&mut keyset)?;
+
+// Save the trie to a file
+trie.save("my_trie.marisa")?;
+
+// Load the trie from a file
+let mut loaded_trie = Trie::new();
+loaded_trie.load("my_trie.marisa")?;
+
+// Or use memory mapping for better performance with large tries
+let mut mmapped_trie = Trie::new();
+mmapped_trie.mmap("my_trie.marisa")?;
+
+// Check the serialized size of the trie (in bytes)
+println!("Trie size: {} bytes", trie.io_size());
+```
+
### Lookup Operations
```rust
@@ -128,6 +156,11 @@ trie.build(&mut keyset)?;
- `trie.predictive_search(query, callback)` - Find all keys that start with query
- `trie.size()` - Get the number of keys in the trie
- `trie.is_empty()` - Check if the trie is empty
+- `trie.save(path)` - Save the trie to a file (returns `Result<(), &str>`)
+- `trie.load(path)` - Load a trie from a file (returns `Result<(), &str>`)
+- `trie.mmap(path)` - Memory-map a trie file for efficient read-only access (returns `Result<(), &str>`)
+- `trie.io_size()` - Get the serialized size of the trie in bytes
+- `trie.clear()` - Clear the trie, removing all keys (returns `Result<(), &str>`)
## Japanese Text Example
diff --git a/src/lib.rs b/src/lib.rs
index 4b1a656..24bc247 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -29,6 +29,8 @@
//! ```
use std::slice;
+use std::path::Path;
+use std::ffi::CString;
mod bindings {
include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
@@ -198,13 +200,124 @@ impl Trie {
/// Returns the number of keys stored in the trie.
pub fn size(&self) -> usize {
- unsafe { marisa_trie_size(self.inner) }
+ unsafe {
+ let mut size: usize = 0;
+ if marisa_trie_size(self.inner, &mut size) == 1 {
+ size
+ } else {
+ 0 // The wrapper reported failure; the error is not surfaced to the caller and 0 is returned instead
+ }
+ }
}
/// Returns true if the trie is empty.
pub fn is_empty(&self) -> bool {
self.size() == 0
}
+
+    /// Saves the trie to a file.
+    ///
+    /// # Arguments
+    ///
+    /// * `path` - The file path to save the trie to
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if `path` is not valid UTF-8, contains a NUL byte,
+    /// or the trie cannot be saved to the file.
+    pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), &'static str> {
+        // Reject non-UTF-8 paths instead of converting them lossily: a lossy
+        // conversion would silently write to a differently named file.
+        let path_str = path
+            .as_ref()
+            .to_str()
+            .ok_or("Invalid path: not valid UTF-8")?;
+        let c_path = CString::new(path_str)
+            .map_err(|_| "Invalid path: contains null bytes")?;
+
+        // SAFETY: `c_path` is a NUL-terminated string that outlives the call.
+        // Relies on `self.inner` being a valid trie handle for the lifetime of
+        // `self` (its construction is not visible in this hunk).
+        unsafe {
+            if marisa_trie_save(self.inner, c_path.as_ptr()) == 1 {
+                Ok(())
+            } else {
+                Err("Failed to save trie")
+            }
+        }
+    }
+
+    /// Loads a trie from a file.
+    ///
+    /// # Arguments
+    ///
+    /// * `path` - The file path to load the trie from
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if `path` is not valid UTF-8, contains a NUL byte,
+    /// or the trie cannot be loaded from the file.
+    pub fn load<P: AsRef<Path>>(&mut self, path: P) -> Result<(), &'static str> {
+        // Reject non-UTF-8 paths instead of converting them lossily: a lossy
+        // conversion would silently try to open a differently named file.
+        let path_str = path
+            .as_ref()
+            .to_str()
+            .ok_or("Invalid path: not valid UTF-8")?;
+        let c_path = CString::new(path_str)
+            .map_err(|_| "Invalid path: contains null bytes")?;
+
+        // SAFETY: `c_path` is a NUL-terminated string that outlives the call.
+        // Relies on `self.inner` being a valid trie handle for the lifetime of
+        // `self` (its construction is not visible in this hunk).
+        unsafe {
+            if marisa_trie_load(self.inner, c_path.as_ptr()) == 1 {
+                Ok(())
+            } else {
+                Err("Failed to load trie")
+            }
+        }
+    }
+
+    /// Memory-maps a trie file for read-only access.
+    ///
+    /// This is more memory-efficient than loading the entire trie into memory.
+    ///
+    /// # Arguments
+    ///
+    /// * `path` - The file path to memory-map the trie from
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if `path` is not valid UTF-8, contains a NUL byte,
+    /// or the trie cannot be memory-mapped from the file.
+    pub fn mmap<P: AsRef<Path>>(&mut self, path: P) -> Result<(), &'static str> {
+        // Reject non-UTF-8 paths instead of converting them lossily: a lossy
+        // conversion would silently try to map a differently named file.
+        let path_str = path
+            .as_ref()
+            .to_str()
+            .ok_or("Invalid path: not valid UTF-8")?;
+        let c_path = CString::new(path_str)
+            .map_err(|_| "Invalid path: contains null bytes")?;
+
+        // SAFETY: `c_path` is a NUL-terminated string that outlives the call.
+        // Relies on `self.inner` being a valid trie handle for the lifetime of
+        // `self` (its construction is not visible in this hunk).
+        unsafe {
+            if marisa_trie_mmap(self.inner, c_path.as_ptr()) == 1 {
+                Ok(())
+            } else {
+                Err("Failed to memory-map trie")
+            }
+        }
+    }
+
+    /// Reports how many bytes the trie occupies in serialized form.
+    ///
+    /// Handy for estimating storage requirements before the trie is
+    /// written to a file. Yields `0` when the underlying call reports
+    /// failure.
+    pub fn io_size(&self) -> usize {
+        let mut bytes: usize = 0;
+        // SAFETY: `bytes` is a live local, so the out-pointer is valid for the
+        // call. Relies on `self.inner` being a valid trie handle for the
+        // lifetime of `self` (its construction is not visible in this hunk).
+        let status = unsafe { marisa_trie_io_size(self.inner, &mut bytes) };
+        if status == 1 {
+            bytes
+        } else {
+            0
+        }
+    }
+
+    /// Empties the trie, discarding every key it holds.
+    ///
+    /// Once cleared, the trie contains nothing and has to be populated
+    /// again (for example by building it from a keyset) before it is
+    /// useful for lookups.
+    ///
+    /// # Errors
+    ///
+    /// Returns an error if the trie cannot be cleared.
+    pub fn clear(&mut self) -> Result<(), &'static str> {
+        // SAFETY: relies on `self.inner` being a valid trie handle for the
+        // lifetime of `self` (its construction is not visible in this hunk).
+        let status = unsafe { marisa_trie_clear(self.inner) };
+        match status {
+            1 => Ok(()),
+            _ => Err("Failed to clear trie"),
+        }
+    }
}
impl Drop for Trie {
@@ -330,4 +443,101 @@ mod tests {
assert!(trie.is_empty());
assert_eq!(trie.size(), 0);
}
+
+    #[test]
+    fn test_save_and_load() {
+        use std::fs;
+
+        // Keep the fixture out of the crate directory; the process id keeps
+        // concurrent test runs from clobbering each other's file.
+        let test_file = std::env::temp_dir()
+            .join(format!("marisa_rs_test_trie_{}.marisa", std::process::id()));
+
+        // Clean up from any previous test
+        let _ = fs::remove_file(&test_file);
+
+        // Build a trie
+        let mut keyset = Keyset::new();
+        keyset.push("apple");
+        keyset.push("application");
+        keyset.push("apply");
+
+        let mut original_trie = Trie::new();
+        original_trie.build(&mut keyset).expect("Failed to build trie");
+
+        // Test io_size
+        assert!(original_trie.io_size() > 0);
+
+        // Save the trie
+        original_trie.save(&test_file).expect("Failed to save trie");
+
+        // Load the trie
+        let mut loaded_trie = Trie::new();
+        loaded_trie.load(&test_file).expect("Failed to load trie");
+
+        // Verify the loaded trie works the same
+        assert_eq!(loaded_trie.size(), original_trie.size());
+        assert!(loaded_trie.lookup("apple").is_some());
+        assert!(loaded_trie.lookup("application").is_some());
+        assert!(loaded_trie.lookup("apply").is_some());
+        assert!(loaded_trie.lookup("banana").is_none());
+
+        // Test reverse lookup
+        if let Some(id) = loaded_trie.lookup("apple") {
+            assert_eq!(loaded_trie.reverse_lookup(id).unwrap(), "apple");
+        }
+
+        // Clean up
+        let _ = fs::remove_file(&test_file);
+    }
+
+    #[test]
+    fn test_mmap() {
+        use std::fs;
+
+        // Keep the fixture out of the crate directory; the process id keeps
+        // concurrent test runs from clobbering each other's file.
+        let test_file = std::env::temp_dir()
+            .join(format!("marisa_rs_test_mmap_trie_{}.marisa", std::process::id()));
+
+        // Clean up from any previous test
+        let _ = fs::remove_file(&test_file);
+
+        // Build and save a trie
+        let mut keyset = Keyset::new();
+        keyset.push("memory");
+        keyset.push("mapped");
+        keyset.push("trie");
+
+        let mut original_trie = Trie::new();
+        original_trie.build(&mut keyset).expect("Failed to build trie");
+        original_trie.save(&test_file).expect("Failed to save trie");
+
+        // Memory-map the trie
+        let mut mmapped_trie = Trie::new();
+        mmapped_trie.mmap(&test_file).expect("Failed to mmap trie");
+
+        // Verify the memory-mapped trie works
+        assert_eq!(mmapped_trie.size(), 3);
+        assert!(mmapped_trie.lookup("memory").is_some());
+        assert!(mmapped_trie.lookup("mapped").is_some());
+        assert!(mmapped_trie.lookup("trie").is_some());
+        assert!(mmapped_trie.lookup("nonexistent").is_none());
+
+        // Release the trie before deleting the file it was mapped from
+        // (presumably dropping it unmaps the file; some platforms refuse to
+        // delete a file that is still mapped).
+        drop(mmapped_trie);
+
+        // Clean up
+        let _ = fs::remove_file(&test_file);
+    }
+
+    #[test]
+    fn test_clear() {
+        // Build a one-key trie so there is something to clear.
+        let mut keys = Keyset::new();
+        keys.push("test");
+        let mut subject = Trie::new();
+        subject.build(&mut keys).expect("Failed to build trie");
+
+        // Sanity check: the key is counted before clearing.
+        assert!(!subject.is_empty());
+        assert_eq!(subject.size(), 1);
+
+        subject.clear().expect("Failed to clear trie");
+
+        // After clearing, the trie reports no keys and lookups miss.
+        assert!(subject.is_empty());
+        assert_eq!(subject.size(), 0);
+        assert!(subject.lookup("test").is_none());
+    }
}
\ No newline at end of file
diff --git a/src/main.rs b/src/main.rs
index 4bffc74..7da8725 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -47,4 +47,124 @@ fn main() {
trie.predictive_search("app", |key, id| {
println!(" Found: '{}' (ID: {})", key, id);
});
+
+ test_save_load_functionality(&trie);
+}
+
+/// Demonstrates saving, loading, memory-mapping and clearing a trie.
+///
+/// Writes `demo_trie.marisa` into the current directory while it runs and
+/// removes it again before returning, even when one of the steps fails.
+fn test_save_load_functionality(original_trie: &Trie) {
+    println!("\n=== Save/Load Functionality Test ===");
+
+    let test_file = "demo_trie.marisa";
+
+    // All tries created by the checks are dropped when this call returns,
+    // i.e. before the file is deleted below.
+    run_save_load_checks(original_trie, test_file);
+
+    // Clean up here rather than at the end of the checks so that an early
+    // return from a failed step cannot leave the demo file behind.
+    match std::fs::remove_file(test_file) {
+        Ok(()) => println!("\n✓ Cleanup: Removed test file '{}'", test_file),
+        // Nothing to remove, e.g. when saving failed before the file was created.
+        Err(e) if e.kind() == std::io::ErrorKind::NotFound => {}
+        Err(e) => println!("\n⚠ Warning: Could not remove test file '{}': {}", test_file, e),
+    }
+}
+
+/// Runs the save, load, verify, memory-map and clear steps of the demo.
+///
+/// Returns early (after printing the failure) when saving or loading fails;
+/// removing `test_file` is left to the caller.
+fn run_save_load_checks(original_trie: &Trie, test_file: &str) {
+    // Test io_size
+    let trie_size = original_trie.io_size();
+    println!("Original trie serialized size: {} bytes", trie_size);
+
+    // Test save
+    println!("\nSaving trie to '{}'...", test_file);
+    match original_trie.save(test_file) {
+        Ok(()) => println!("✓ Trie saved successfully!"),
+        Err(e) => {
+            println!("✗ Failed to save trie: {}", e);
+            return;
+        }
+    }
+
+    // Test load
+    println!("\nLoading trie from '{}'...", test_file);
+    let mut loaded_trie = Trie::new();
+    match loaded_trie.load(test_file) {
+        Ok(()) => println!("✓ Trie loaded successfully! Size: {}", loaded_trie.size()),
+        Err(e) => {
+            println!("✗ Failed to load trie: {}", e);
+            return;
+        }
+    }
+
+    // Verify loaded trie functionality
+    println!("\n--- Verifying Loaded Trie ---");
+    let test_words = ["apple", "application", "banana"];
+    for word in &test_words {
+        let original_result = original_trie.lookup(word);
+        let loaded_result = loaded_trie.lookup(word);
+        match (original_result, loaded_result) {
+            (Some(orig_id), Some(load_id)) if orig_id == load_id => {
+                println!("✓ '{}' lookup matches (ID: {})", word, orig_id);
+            }
+            (None, None) => {
+                println!("✓ '{}' not found in both tries", word);
+            }
+            _ => {
+                println!("✗ '{}' lookup mismatch: orig={:?}, loaded={:?}",
+                    word, original_result, loaded_result);
+            }
+        }
+    }
+
+    // Test memory mapping
+    println!("\n--- Testing Memory Mapping ---");
+    let mut mmapped_trie = Trie::new();
+    match mmapped_trie.mmap(test_file) {
+        Ok(()) => {
+            println!("✓ Trie memory-mapped successfully! Size: {}", mmapped_trie.size());
+
+            // Test a lookup on memory-mapped trie
+            if let Some(id) = mmapped_trie.lookup("apple") {
+                match mmapped_trie.reverse_lookup(id) {
+                    Ok(word) => println!("✓ Memory-mapped reverse lookup: ID {} -> '{}'", id, word),
+                    Err(e) => println!("✗ Memory-mapped reverse lookup failed: {}", e),
+                }
+            }
+        }
+        Err(e) => println!("✗ Failed to memory-map trie: {}", e),
+    }
+
+    run_clear_check();
+}
+
+/// Builds a small throwaway trie, clears it, and reports whether it ended up empty.
+fn run_clear_check() {
+    println!("\n--- Testing Clear Functionality ---");
+    let mut test_trie = Trie::new();
+    let mut keyset = Keyset::new();
+    keyset.push("test");
+    keyset.push("clear");
+
+    if test_trie.build(&mut keyset).is_err() {
+        println!("✗ Failed to build test trie for clear functionality");
+        return;
+    }
+
+    println!("Created test trie with {} keys", test_trie.size());
+
+    // Test lookup before clear
+    let before_lookup = test_trie.lookup("test").is_some();
+    println!("Before clear - 'test' found: {}", before_lookup);
+
+    // Clear the trie
+    match test_trie.clear() {
+        Ok(()) => println!("✓ Trie clear operation successful"),
+        Err(e) => {
+            println!("✗ Failed to clear trie: {}", e);
+            return;
+        }
+    }
+
+    // Test after clear
+    println!("Testing trie state after clear...");
+
+    let after_size = test_trie.size();
+    let is_empty = test_trie.is_empty();
+    let after_lookup = test_trie.lookup("test").is_some();
+
+    println!(" Size after clear: {}", after_size);
+    println!(" Is empty after clear: {}", is_empty);
+    println!(" 'test' found after clear: {}", after_lookup);
+
+    if after_size == 0 && is_empty && !after_lookup {
+        println!("✓ Clear functionality works correctly!");
+    } else {
+        println!("⚠ Clear may not have worked as expected");
+    }
+}
diff --git a/wrapper.cpp b/wrapper.cpp
index 7fea9d2..8143b4a 100644
--- a/wrapper.cpp
+++ b/wrapper.cpp
@@ -82,9 +82,64 @@ int marisa_trie_predictive_search(const MarisaTrie* trie, MarisaAgent* agent) {
}
}
-size_t marisa_trie_size(const MarisaTrie* trie) {
- const marisa::Trie* tr = reinterpret_cast<const marisa::Trie*>(trie);
- return tr->size();
+// Writes the number of keys in `trie` to `*size`.
+// Returns 1 on success, 0 if the underlying call throws (`*size` is then left untouched).
+int marisa_trie_size(const MarisaTrie* trie, size_t* size) {
+    try {
+        const marisa::Trie* tr = reinterpret_cast<const marisa::Trie*>(trie);
+        *size = tr->size();
+        return 1;
+    } catch (...) {
+        return 0;
+    }
+}
+
+// Saves `trie` to the file named by `filename`.
+// Returns 1 on success, 0 if the underlying call throws.
+int marisa_trie_save(const MarisaTrie* trie, const char* filename) {
+    try {
+        const marisa::Trie* tr = reinterpret_cast<const marisa::Trie*>(trie);
+        tr->save(filename);
+        return 1;
+    } catch (...) {
+        return 0;
+    }
+}
+
+// Loads the file named by `filename` into `trie`.
+// Returns 1 on success, 0 if the underlying call throws.
+int marisa_trie_load(MarisaTrie* trie, const char* filename) {
+    try {
+        marisa::Trie* tr = reinterpret_cast<marisa::Trie*>(trie);
+        tr->load(filename);
+        return 1;
+    } catch (...) {
+        return 0;
+    }
+}
+
+// Memory-maps the file named by `filename` into `trie`.
+// Returns 1 on success, 0 if the underlying call throws.
+int marisa_trie_mmap(MarisaTrie* trie, const char* filename) {
+    try {
+        marisa::Trie* tr = reinterpret_cast<marisa::Trie*>(trie);
+        tr->mmap(filename);
+        return 1;
+    } catch (...) {
+        return 0;
+    }
+}
+
+// Writes the serialized size of `trie` to `*size`.
+// Returns 1 on success, 0 if the underlying call throws (`*size` is then left untouched).
+int marisa_trie_io_size(const MarisaTrie* trie, size_t* size) {
+    try {
+        const marisa::Trie* tr = reinterpret_cast<const marisa::Trie*>(trie);
+        *size = tr->io_size();
+        return 1;
+    } catch (...) {
+        return 0;
+    }
+}
+
+// Clears `trie`.
+// Returns 1 on success, 0 if the underlying call throws.
+int marisa_trie_clear(MarisaTrie* trie) {
+    try {
+        marisa::Trie* tr = reinterpret_cast<marisa::Trie*>(trie);
+        tr->clear();
+        return 1;
+    } catch (...) {
+        return 0;
+    }
}
MarisaAgent* marisa_agent_new() {
diff --git a/wrapper.h b/wrapper.h
index 9289af5..09e044a 100644
--- a/wrapper.h
+++ b/wrapper.h
@@ -25,7 +25,12 @@ int marisa_trie_lookup(const MarisaTrie* trie, MarisaAgent* agent);
int marisa_trie_reverse_lookup(const MarisaTrie* trie, MarisaAgent* agent);
int marisa_trie_common_prefix_search(const MarisaTrie* trie, MarisaAgent* agent);
int marisa_trie_predictive_search(const MarisaTrie* trie, MarisaAgent* agent);
-size_t marisa_trie_size(const MarisaTrie* trie);
+int marisa_trie_size(const MarisaTrie* trie, size_t* size);
+int marisa_trie_save(const MarisaTrie* trie, const char* filename);
+int marisa_trie_load(MarisaTrie* trie, const char* filename);
+int marisa_trie_mmap(MarisaTrie* trie, const char* filename);
+int marisa_trie_io_size(const MarisaTrie* trie, size_t* size);
+int marisa_trie_clear(MarisaTrie* trie);
// Agent functions
MarisaAgent* marisa_agent_new();