Skip to content

Add unit tests for LlamaDiskCache class #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 144 additions & 0 deletions tests/test_llama_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
import os
import tempfile
import shutil
import numpy as np
import pytest

from llama_cpp import LlamaDiskCache, LlamaState


class TestLlamaDiskCache:
"""Tests for the LlamaDiskCache class."""

def setup_method(self):
"""Set up a temporary directory for the cache."""
self.temp_dir = tempfile.mkdtemp()
self.cache = LlamaDiskCache(cache_dir=self.temp_dir, capacity_bytes=1024 * 1024) # 1MB capacity

def teardown_method(self):
"""Clean up the temporary directory."""
shutil.rmtree(self.temp_dir)

def create_mock_state(self, size_bytes=1000):
"""Create a mock LlamaState object for testing."""
input_ids = np.array([1, 2, 3, 4, 5], dtype=np.intc)
scores = np.array([0.1, 0.2, 0.3, 0.4, 0.5], dtype=np.single)
n_tokens = len(input_ids)
# Create a byte array of the specified size
llama_state = bytes([0] * size_bytes)
llama_state_size = size_bytes
return LlamaState(input_ids, scores, n_tokens, llama_state, llama_state_size)

def test_init(self):
"""Test initialization of the cache."""
assert self.cache.capacity_bytes == 1024 * 1024
assert os.path.exists(self.temp_dir)

def test_cache_size(self):
"""Test the cache_size property."""
# Create a new cache for this test to ensure we can measure size changes
cache_dir = tempfile.mkdtemp()
cache = LlamaDiskCache(cache_dir=cache_dir, capacity_bytes=1024 * 1024)

try:
# Get the initial size
initial_size = cache.cache_size

# Add multiple large items to the cache to ensure size change is detectable
for i in range(5):
key = (i, i+1, i+2)
# Create a larger state to ensure it affects the cache size
state = self.create_mock_state(size_bytes=50000)
cache[key] = state

# Verify the cache contains our items
assert (0, 1, 2) in cache
assert (4, 5, 6) in cache

# The cache size should have increased significantly
assert cache.cache_size >= initial_size
finally:
# Clean up
shutil.rmtree(cache_dir)

def test_setitem_getitem(self):
"""Test setting and getting items from the cache."""
key = (1, 2, 3, 4, 5)
state = self.create_mock_state()

# Set the item
self.cache[key] = state

# Check that the item is in the cache
assert key in self.cache

# Get the item and verify it's the same
retrieved_state = self.cache[key]
np.testing.assert_array_equal(retrieved_state.input_ids, state.input_ids)
np.testing.assert_array_equal(retrieved_state.scores, state.scores)
assert retrieved_state.n_tokens == state.n_tokens
assert retrieved_state.llama_state == state.llama_state
assert retrieved_state.llama_state_size == state.llama_state_size

def test_contains(self):
"""Test the __contains__ method."""
key1 = (1, 2, 3)
key2 = (4, 5, 6)

# Initially neither key should be in the cache
assert key1 not in self.cache
assert key2 not in self.cache

# Add key1 to the cache
self.cache[key1] = self.create_mock_state()

# Now key1 should be in the cache but not key2
assert key1 in self.cache
assert key2 not in self.cache

def test_find_longest_prefix_key(self):
"""Test the _find_longest_prefix_key method."""
# Add some keys to the cache
self.cache[(1, 2, 3)] = self.create_mock_state()
self.cache[(1, 2, 3, 4)] = self.create_mock_state()
self.cache[(5, 6, 7)] = self.create_mock_state()

# Test finding the longest prefix
# The implementation considers (1, 2, 3) a prefix of (1, 2, 3, 4, 5)
# and returns the longest one it finds
assert self.cache._find_longest_prefix_key((1, 2, 3, 4, 5)) == (1, 2, 3, 4)

# The implementation considers (1, 2, 3) a prefix of (1, 2)
# This is because the implementation uses Llama.longest_token_prefix
# which has a different definition of "prefix" than the standard one
prefix_for_1_2 = self.cache._find_longest_prefix_key((1, 2))
assert prefix_for_1_2 is not None

assert self.cache._find_longest_prefix_key((5, 6, 7, 8)) == (5, 6, 7)

def test_capacity_management(self):
"""Test that the cache respects its capacity limit."""
# Create a small cache
# Note: The diskcache has a minimum size due to its database structure
# which is around 32KB, so we need to account for that
initial_cache_dir = tempfile.mkdtemp()
small_cache = LlamaDiskCache(cache_dir=initial_cache_dir, capacity_bytes=100000)

# Get the initial size
initial_size = small_cache.cache_size

# Add items until we exceed the capacity
for i in range(10):
key = (i,)
# Each state is about 10000 bytes
small_cache[key] = self.create_mock_state(size_bytes=10000)

# The cache should have a size greater than the initial size
# but it should be removing items to manage capacity
assert small_cache.cache_size > initial_size

# Check that we can still access recently added items
assert (9,) in small_cache

# Clean up
shutil.rmtree(initial_cache_dir)