update save functions

This commit is contained in:
Soma Nakamura 2025-07-24 04:25:06 +09:00
parent 0d4d05a2f3
commit 46a1077d99
7 changed files with 431 additions and 8 deletions

2
Cargo.lock generated
View file

@ -157,7 +157,7 @@ checksum = "13dc2df351e3202783a1fe0d44375f7295ffb4049267b0f3018346dc122a1d94"
[[package]] [[package]]
name = "marisa-rs" name = "marisa-rs"
version = "0.1.0" version = "0.1.1"
dependencies = [ dependencies = [
"bindgen", "bindgen",
"cc", "cc",

View file

@ -1,6 +1,6 @@
[package] [package]
name = "marisa-rs" name = "marisa-rs"
version = "0.1.0" version = "0.1.1"
edition = "2021" edition = "2021"
description = "Safe Rust wrapper for the marisa-trie C++ library" description = "Safe Rust wrapper for the marisa-trie C++ library"
license = "LGPL-2.1-or-later" license = "LGPL-2.1-or-later"

View file

@ -57,11 +57,39 @@ keyset.push("car");
keyset.push("card"); keyset.push("card");
keyset.push("care"); keyset.push("care");
// Build the trie // Build the trie
let mut trie = Trie::new(); let mut trie = Trie::new();
trie.build(&mut keyset)?; trie.build(&mut keyset)?;
``` ```
### Saving and Loading Tries
```rust
use marisa_rs::{Keyset, Trie};
// Build a trie
let mut keyset = Keyset::new();
keyset.push("hello");
keyset.push("world");
let mut trie = Trie::new();
trie.build(&mut keyset)?;
// Save the trie to a file
trie.save("my_trie.marisa")?;
// Load the trie from a file
let mut loaded_trie = Trie::new();
loaded_trie.load("my_trie.marisa")?;
// Or use memory mapping for better performance with large tries
let mut mmapped_trie = Trie::new();
mmapped_trie.mmap("my_trie.marisa")?;
// Check the serialized size before saving
println!("Trie size: {} bytes", trie.io_size());
```
### Lookup Operations ### Lookup Operations
```rust ```rust
@ -128,6 +156,11 @@ trie.build(&mut keyset)?;
- `trie.predictive_search(query, callback)` - Find all keys that start with query - `trie.predictive_search(query, callback)` - Find all keys that start with query
- `trie.size()` - Get the number of keys in the trie - `trie.size()` - Get the number of keys in the trie
- `trie.is_empty()` - Check if the trie is empty - `trie.is_empty()` - Check if the trie is empty
- `trie.save(path)` - Save the trie to a file (returns `Result<(), &str>`)
- `trie.load(path)` - Load a trie from a file (returns `Result<(), &str>`)
- `trie.mmap(path)` - Memory-map a trie file for efficient read-only access (returns `Result<(), &str>`)
- `trie.io_size()` - Get the serialized size of the trie in bytes
- `trie.clear()` - Clear the trie, removing all keys (returns `Result<(), &str>`)
## Japanese Text Example ## Japanese Text Example

View file

@ -29,6 +29,8 @@
//! ``` //! ```
use std::slice; use std::slice;
use std::path::Path;
use std::ffi::CString;
mod bindings { mod bindings {
include!(concat!(env!("OUT_DIR"), "/bindings.rs")); include!(concat!(env!("OUT_DIR"), "/bindings.rs"));
@ -198,13 +200,124 @@ impl Trie {
/// Returns the number of keys stored in the trie. /// Returns the number of keys stored in the trie.
pub fn size(&self) -> usize { pub fn size(&self) -> usize {
unsafe { marisa_trie_size(self.inner) } unsafe {
let mut size: usize = 0;
if marisa_trie_size(self.inner, &mut size) == 1 {
size
} else {
0 // Return 0 if there's an error
}
}
} }
/// Returns true if the trie is empty. /// Returns true if the trie is empty.
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
self.size() == 0 self.size() == 0
} }
/// Saves the trie to a file.
///
/// # Arguments
///
/// * `path` - The file path to save the trie to
///
/// # Errors
///
/// Returns an error if the trie cannot be saved to the file.
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<(), &'static str> {
let path_str = path.as_ref().to_string_lossy();
let c_path = CString::new(path_str.as_ref())
.map_err(|_| "Invalid path: contains null bytes")?;
unsafe {
if marisa_trie_save(self.inner, c_path.as_ptr()) == 1 {
Ok(())
} else {
Err("Failed to save trie")
}
}
}
/// Loads a trie from a file.
///
/// # Arguments
///
/// * `path` - The file path to load the trie from
///
/// # Errors
///
/// Returns an error if the trie cannot be loaded from the file.
pub fn load<P: AsRef<Path>>(&mut self, path: P) -> Result<(), &'static str> {
let path_str = path.as_ref().to_string_lossy();
let c_path = CString::new(path_str.as_ref())
.map_err(|_| "Invalid path: contains null bytes")?;
unsafe {
if marisa_trie_load(self.inner, c_path.as_ptr()) == 1 {
Ok(())
} else {
Err("Failed to load trie")
}
}
}
/// Memory-maps a trie file for read-only access.
///
/// This is more memory-efficient than loading the entire trie into memory.
///
/// # Arguments
///
/// * `path` - The file path to memory-map the trie from
///
/// # Errors
///
/// Returns an error if the trie cannot be memory-mapped from the file.
pub fn mmap<P: AsRef<Path>>(&mut self, path: P) -> Result<(), &'static str> {
let path_str = path.as_ref().to_string_lossy();
let c_path = CString::new(path_str.as_ref())
.map_err(|_| "Invalid path: contains null bytes")?;
unsafe {
if marisa_trie_mmap(self.inner, c_path.as_ptr()) == 1 {
Ok(())
} else {
Err("Failed to memory-map trie")
}
}
}
/// Returns the size of the trie when serialized.
///
/// This can be useful for determining the storage requirements
/// before saving the trie to a file.
pub fn io_size(&self) -> usize {
unsafe {
let mut size: usize = 0;
if marisa_trie_io_size(self.inner, &mut size) == 1 {
size
} else {
0 // Return 0 if there's an error
}
}
}
/// Clears the trie, removing all keys.
///
/// After calling this method, the trie will be empty and
/// must be rebuilt from a keyset before it can be used again.
///
/// # Errors
///
/// Returns an error if the trie cannot be cleared.
pub fn clear(&mut self) -> Result<(), &'static str> {
unsafe {
if marisa_trie_clear(self.inner) == 1 {
Ok(())
} else {
Err("Failed to clear trie")
}
}
}
} }
impl Drop for Trie { impl Drop for Trie {
@ -330,4 +443,101 @@ mod tests {
assert!(trie.is_empty()); assert!(trie.is_empty());
assert_eq!(trie.size(), 0); assert_eq!(trie.size(), 0);
} }
#[test]
fn test_save_and_load() {
use std::fs;
let test_file = "test_trie.marisa";
// Clean up from any previous test
let _ = fs::remove_file(test_file);
// Build a trie
let mut keyset = Keyset::new();
keyset.push("apple");
keyset.push("application");
keyset.push("apply");
let mut original_trie = Trie::new();
original_trie.build(&mut keyset).expect("Failed to build trie");
// Test io_size
assert!(original_trie.io_size() > 0);
// Save the trie
original_trie.save(test_file).expect("Failed to save trie");
// Load the trie
let mut loaded_trie = Trie::new();
loaded_trie.load(test_file).expect("Failed to load trie");
// Verify the loaded trie works the same
assert_eq!(loaded_trie.size(), original_trie.size());
assert!(loaded_trie.lookup("apple").is_some());
assert!(loaded_trie.lookup("application").is_some());
assert!(loaded_trie.lookup("apply").is_some());
assert!(loaded_trie.lookup("banana").is_none());
// Test reverse lookup
if let Some(id) = loaded_trie.lookup("apple") {
assert_eq!(loaded_trie.reverse_lookup(id).unwrap(), "apple");
}
// Clean up
let _ = fs::remove_file(test_file);
}
#[test]
fn test_mmap() {
use std::fs;
let test_file = "test_mmap_trie.marisa";
// Clean up from any previous test
let _ = fs::remove_file(test_file);
// Build and save a trie
let mut keyset = Keyset::new();
keyset.push("memory");
keyset.push("mapped");
keyset.push("trie");
let mut original_trie = Trie::new();
original_trie.build(&mut keyset).expect("Failed to build trie");
original_trie.save(test_file).expect("Failed to save trie");
// Memory-map the trie
let mut mmapped_trie = Trie::new();
mmapped_trie.mmap(test_file).expect("Failed to mmap trie");
// Verify the memory-mapped trie works
assert_eq!(mmapped_trie.size(), 3);
assert!(mmapped_trie.lookup("memory").is_some());
assert!(mmapped_trie.lookup("mapped").is_some());
assert!(mmapped_trie.lookup("trie").is_some());
assert!(mmapped_trie.lookup("nonexistent").is_none());
// Clean up
let _ = fs::remove_file(test_file);
}
#[test]
fn test_clear() {
let mut keyset = Keyset::new();
keyset.push("test");
let mut trie = Trie::new();
trie.build(&mut keyset).expect("Failed to build trie");
assert_eq!(trie.size(), 1);
assert!(!trie.is_empty());
// Clear the trie
trie.clear().expect("Failed to clear trie");
assert_eq!(trie.size(), 0);
assert!(trie.is_empty());
assert!(trie.lookup("test").is_none());
}
} }

View file

@ -47,4 +47,124 @@ fn main() {
trie.predictive_search("app", |key, id| { trie.predictive_search("app", |key, id| {
println!(" Found: '{}' (ID: {})", key, id); println!(" Found: '{}' (ID: {})", key, id);
}); });
test_save_load_functionality(&trie);
}
fn test_save_load_functionality(original_trie: &Trie) {
println!("\n=== Save/Load Functionality Test ===");
let test_file = "demo_trie.marisa";
// Test io_size
let trie_size = original_trie.io_size();
println!("Original trie serialized size: {} bytes", trie_size);
// Test save
println!("\nSaving trie to '{}'...", test_file);
match original_trie.save(test_file) {
Ok(()) => println!("✓ Trie saved successfully!"),
Err(e) => {
println!("✗ Failed to save trie: {}", e);
return;
}
}
// Test load
println!("\nLoading trie from '{}'...", test_file);
let mut loaded_trie = Trie::new();
match loaded_trie.load(test_file) {
Ok(()) => println!("✓ Trie loaded successfully! Size: {}", loaded_trie.size()),
Err(e) => {
println!("✗ Failed to load trie: {}", e);
return;
}
}
// Verify loaded trie functionality
println!("\n--- Verifying Loaded Trie ---");
let test_words = vec!["apple", "application", "banana"];
for word in &test_words {
let original_result = original_trie.lookup(word);
let loaded_result = loaded_trie.lookup(word);
match (original_result, loaded_result) {
(Some(orig_id), Some(load_id)) if orig_id == load_id => {
println!("✓ '{}' lookup matches (ID: {})", word, orig_id);
}
(None, None) => {
println!("✓ '{}' not found in both tries", word);
}
_ => {
println!("✗ '{}' lookup mismatch: orig={:?}, loaded={:?}",
word, original_result, loaded_result);
}
}
}
// Test memory mapping
println!("\n--- Testing Memory Mapping ---");
let mut mmapped_trie = Trie::new();
match mmapped_trie.mmap(test_file) {
Ok(()) => {
println!("✓ Trie memory-mapped successfully! Size: {}", mmapped_trie.size());
// Test a lookup on memory-mapped trie
if let Some(id) = mmapped_trie.lookup("apple") {
match mmapped_trie.reverse_lookup(id) {
Ok(word) => println!("✓ Memory-mapped reverse lookup: ID {} -> '{}'", id, word),
Err(e) => println!("✗ Memory-mapped reverse lookup failed: {}", e),
}
}
}
Err(e) => println!("✗ Failed to memory-map trie: {}", e),
}
// Test clear functionality
println!("\n--- Testing Clear Functionality ---");
let mut test_trie = Trie::new();
let mut keyset = Keyset::new();
keyset.push("test");
keyset.push("clear");
if test_trie.build(&mut keyset).is_ok() {
println!("Created test trie with {} keys", test_trie.size());
// Test lookup before clear
let before_lookup = test_trie.lookup("test").is_some();
println!("Before clear - 'test' found: {}", before_lookup);
// Clear the trie
match test_trie.clear() {
Ok(()) => println!("✓ Trie clear operation successful"),
Err(e) => {
println!("✗ Failed to clear trie: {}", e);
return;
}
}
// Test after clear
println!("Testing trie state after clear...");
let after_size = test_trie.size();
let is_empty = test_trie.is_empty();
let after_lookup = test_trie.lookup("test").is_some();
println!(" Size after clear: {}", after_size);
println!(" Is empty after clear: {}", is_empty);
println!(" 'test' found after clear: {}", after_lookup);
if after_size == 0 && is_empty && !after_lookup {
println!("✓ Clear functionality works correctly!");
} else {
println!("⚠ Clear may not have worked as expected");
}
} else {
println!("✗ Failed to build test trie for clear functionality");
}
// Clean up
match std::fs::remove_file(test_file) {
Ok(()) => println!("\n✓ Cleanup: Removed test file '{}'", test_file),
Err(e) => println!("\n⚠ Warning: Could not remove test file '{}': {}", test_file, e),
}
} }

View file

@ -82,9 +82,64 @@ int marisa_trie_predictive_search(const MarisaTrie* trie, MarisaAgent* agent) {
} }
} }
size_t marisa_trie_size(const MarisaTrie* trie) { int marisa_trie_size(const MarisaTrie* trie, size_t* size) {
const marisa::Trie* tr = reinterpret_cast<const marisa::Trie*>(trie); try {
return tr->size(); const marisa::Trie* tr = reinterpret_cast<const marisa::Trie*>(trie);
*size = tr->size();
return 1;
} catch (...) {
return 0;
}
}
int marisa_trie_save(const MarisaTrie* trie, const char* filename) {
try {
const marisa::Trie* tr = reinterpret_cast<const marisa::Trie*>(trie);
tr->save(filename);
return 1;
} catch (...) {
return 0;
}
}
int marisa_trie_load(MarisaTrie* trie, const char* filename) {
try {
marisa::Trie* tr = reinterpret_cast<marisa::Trie*>(trie);
tr->load(filename);
return 1;
} catch (...) {
return 0;
}
}
int marisa_trie_mmap(MarisaTrie* trie, const char* filename) {
try {
marisa::Trie* tr = reinterpret_cast<marisa::Trie*>(trie);
tr->mmap(filename);
return 1;
} catch (...) {
return 0;
}
}
int marisa_trie_io_size(const MarisaTrie* trie, size_t* size) {
try {
const marisa::Trie* tr = reinterpret_cast<const marisa::Trie*>(trie);
*size = tr->io_size();
return 1;
} catch (...) {
return 0;
}
}
int marisa_trie_clear(MarisaTrie* trie) {
try {
marisa::Trie* tr = reinterpret_cast<marisa::Trie*>(trie);
tr->clear();
return 1;
} catch (...) {
return 0;
}
} }
MarisaAgent* marisa_agent_new() { MarisaAgent* marisa_agent_new() {

View file

@ -25,7 +25,12 @@ int marisa_trie_lookup(const MarisaTrie* trie, MarisaAgent* agent);
int marisa_trie_reverse_lookup(const MarisaTrie* trie, MarisaAgent* agent); int marisa_trie_reverse_lookup(const MarisaTrie* trie, MarisaAgent* agent);
int marisa_trie_common_prefix_search(const MarisaTrie* trie, MarisaAgent* agent); int marisa_trie_common_prefix_search(const MarisaTrie* trie, MarisaAgent* agent);
int marisa_trie_predictive_search(const MarisaTrie* trie, MarisaAgent* agent); int marisa_trie_predictive_search(const MarisaTrie* trie, MarisaAgent* agent);
size_t marisa_trie_size(const MarisaTrie* trie); int marisa_trie_size(const MarisaTrie* trie, size_t* size);
int marisa_trie_save(const MarisaTrie* trie, const char* filename);
int marisa_trie_load(MarisaTrie* trie, const char* filename);
int marisa_trie_mmap(MarisaTrie* trie, const char* filename);
int marisa_trie_io_size(const MarisaTrie* trie, size_t* size);
int marisa_trie_clear(MarisaTrie* trie);
// Agent functions // Agent functions
MarisaAgent* marisa_agent_new(); MarisaAgent* marisa_agent_new();