Skip to content

Commit b4a8a0c

Browse files
committed
Use ctypes_extensions module for libllama and libllava
1 parent eb16072 commit b4a8a0c

File tree

2 files changed

+28
-225
lines changed

2 files changed

+28
-225
lines changed

llama_cpp/llama_cpp.py

Lines changed: 17 additions & 123 deletions
Original file line numberDiff line numberDiff line change
@@ -1,151 +1,45 @@
11
from __future__ import annotations
22

3-
import sys
43
import os
54
import ctypes
6-
import functools
75
import pathlib
86

97
from typing import (
10-
Any,
118
Callable,
12-
List,
139
Union,
1410
NewType,
1511
Optional,
1612
TYPE_CHECKING,
17-
TypeVar,
18-
Generic,
1913
)
20-
from typing_extensions import TypeAlias
2114

15+
from llama_cpp._ctypes_extensions import (
16+
load_shared_library,
17+
byref,
18+
ctypes_function_for_shared_library,
19+
)
2220

23-
# Load the library
24-
def _load_shared_library(lib_base_name: str):
25-
# Construct the paths to the possible shared library names
26-
_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
27-
# Searching for the library in the current directory under the name "libllama" (default name
28-
# for llamacpp) and "llama" (default name for this repo)
29-
_lib_paths: List[pathlib.Path] = []
30-
# Determine the file extension based on the platform
31-
if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
32-
_lib_paths += [
33-
_base_path / f"lib{lib_base_name}.so",
34-
]
35-
elif sys.platform == "darwin":
36-
_lib_paths += [
37-
_base_path / f"lib{lib_base_name}.so",
38-
_base_path / f"lib{lib_base_name}.dylib",
39-
]
40-
elif sys.platform == "win32":
41-
_lib_paths += [
42-
_base_path / f"{lib_base_name}.dll",
43-
_base_path / f"lib{lib_base_name}.dll",
44-
]
45-
else:
46-
raise RuntimeError("Unsupported platform")
47-
48-
if "LLAMA_CPP_LIB" in os.environ:
49-
lib_base_name = os.environ["LLAMA_CPP_LIB"]
50-
_lib = pathlib.Path(lib_base_name)
51-
_base_path = _lib.parent.resolve()
52-
_lib_paths = [_lib.resolve()]
53-
54-
cdll_args = dict() # type: ignore
55-
56-
# Add the library directory to the DLL search path on Windows (if needed)
57-
if sys.platform == "win32":
58-
os.add_dll_directory(str(_base_path))
59-
os.environ["PATH"] = str(_base_path) + os.pathsep + os.environ["PATH"]
60-
61-
if sys.platform == "win32" and sys.version_info >= (3, 8):
62-
os.add_dll_directory(str(_base_path))
63-
if "CUDA_PATH" in os.environ:
64-
os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
65-
os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib"))
66-
if "HIP_PATH" in os.environ:
67-
os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "bin"))
68-
os.add_dll_directory(os.path.join(os.environ["HIP_PATH"], "lib"))
69-
cdll_args["winmode"] = ctypes.RTLD_GLOBAL
70-
71-
# Try to load the shared library, handling potential errors
72-
for _lib_path in _lib_paths:
73-
if _lib_path.exists():
74-
try:
75-
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
76-
except Exception as e:
77-
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
78-
79-
raise FileNotFoundError(
80-
f"Shared library with base name '{lib_base_name}' not found"
21+
if TYPE_CHECKING:
22+
from llama_cpp._ctypes_extensions import (
23+
CtypesCData,
24+
CtypesArray,
25+
CtypesPointer,
26+
CtypesVoidPointer,
27+
CtypesRef,
28+
CtypesPointerOrRef,
29+
CtypesFuncPointer,
8130
)
8231

8332

8433
# Specify the base name of the shared library to load
8534
_lib_base_name = "llama"
86-
35+
_override_base_path = os.environ.get("LLAMA_CPP_LIB_PATH")
36+
_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib" if _override_base_path is None else pathlib.Path(_override_base_path)
8737
# Load the library
88-
_lib = _load_shared_library(_lib_base_name)
89-
90-
91-
# ctypes sane type hint helpers
92-
#
93-
# - Generic Pointer and Array types
94-
# - PointerOrRef type with a type hinted byref function
95-
#
96-
# NOTE: Only use these for static type checking not for runtime checks
97-
# no good will come of that
98-
99-
if TYPE_CHECKING:
100-
CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
101-
102-
CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
103-
104-
CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
105-
106-
CtypesVoidPointer: TypeAlias = ctypes.c_void_p
107-
108-
class CtypesRef(Generic[CtypesCData]):
109-
pass
110-
111-
CtypesPointerOrRef: TypeAlias = Union[
112-
CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
113-
]
114-
115-
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
116-
117-
F = TypeVar("F", bound=Callable[..., Any])
118-
119-
120-
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
121-
def ctypes_function(
122-
name: str, argtypes: List[Any], restype: Any, enabled: bool = True
123-
):
124-
def decorator(f: F) -> F:
125-
if enabled:
126-
func = getattr(lib, name)
127-
func.argtypes = argtypes
128-
func.restype = restype
129-
functools.wraps(f)(func)
130-
return func
131-
else:
132-
return f
133-
134-
return decorator
135-
136-
return ctypes_function
137-
38+
_lib = load_shared_library(_lib_base_name, _base_path)
13839

13940
ctypes_function = ctypes_function_for_shared_library(_lib)
14041

14142

142-
def byref(obj: CtypesCData, offset: Optional[int] = None) -> CtypesRef[CtypesCData]:
143-
"""Type-annotated version of ctypes.byref"""
144-
...
145-
146-
147-
byref = ctypes.byref # type: ignore
148-
14943
# from ggml.h
15044
# // NOTE: always add types at the end of the enum to keep backward compatibility
15145
# enum ggml_type {

llama_cpp/llava_cpp.py

Lines changed: 11 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
from __future__ import annotations
22

3-
import sys
43
import os
5-
import ctypes
6-
import functools
74
from ctypes import (
85
c_bool,
96
c_char_p,
@@ -17,121 +14,32 @@
1714
)
1815
import pathlib
1916
from typing import (
20-
List,
2117
Union,
2218
NewType,
2319
Optional,
24-
TypeVar,
25-
Callable,
26-
Any,
2720
TYPE_CHECKING,
28-
Generic,
2921
)
30-
from typing_extensions import TypeAlias
3122

3223
import llama_cpp.llama_cpp as llama_cpp
3324

25+
from llama_cpp._ctypes_extensions import (
26+
load_shared_library,
27+
ctypes_function_for_shared_library,
28+
)
3429

35-
# Load the library
36-
def _load_shared_library(lib_base_name: str):
37-
# Construct the paths to the possible shared library names
38-
_base_path = pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
39-
# Searching for the library in the current directory under the name "libllama" (default name
40-
# for llamacpp) and "llama" (default name for this repo)
41-
_lib_paths: List[pathlib.Path] = []
42-
# Determine the file extension based on the platform
43-
if sys.platform.startswith("linux"):
44-
_lib_paths += [
45-
_base_path / f"lib{lib_base_name}.so",
46-
]
47-
elif sys.platform == "darwin":
48-
_lib_paths += [
49-
_base_path / f"lib{lib_base_name}.so",
50-
_base_path / f"lib{lib_base_name}.dylib",
51-
]
52-
elif sys.platform == "win32":
53-
_lib_paths += [
54-
_base_path / f"{lib_base_name}.dll",
55-
_base_path / f"lib{lib_base_name}.dll",
56-
]
57-
else:
58-
raise RuntimeError("Unsupported platform")
59-
60-
if "LLAVA_CPP_LIB" in os.environ:
61-
lib_base_name = os.environ["LLAVA_CPP_LIB"]
62-
_lib = pathlib.Path(lib_base_name)
63-
_base_path = _lib.parent.resolve()
64-
_lib_paths = [_lib.resolve()]
65-
66-
cdll_args = dict() # type: ignore
67-
# Add the library directory to the DLL search path on Windows (if needed)
68-
if sys.platform == "win32" and sys.version_info >= (3, 8):
69-
os.add_dll_directory(str(_base_path))
70-
if "CUDA_PATH" in os.environ:
71-
os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "bin"))
72-
os.add_dll_directory(os.path.join(os.environ["CUDA_PATH"], "lib"))
73-
cdll_args["winmode"] = ctypes.RTLD_GLOBAL
74-
75-
# Try to load the shared library, handling potential errors
76-
for _lib_path in _lib_paths:
77-
if _lib_path.exists():
78-
try:
79-
return ctypes.CDLL(str(_lib_path), **cdll_args) # type: ignore
80-
except Exception as e:
81-
raise RuntimeError(f"Failed to load shared library '{_lib_path}': {e}")
82-
83-
raise FileNotFoundError(
84-
f"Shared library with base name '{lib_base_name}' not found"
30+
if TYPE_CHECKING:
31+
from llama_cpp._ctypes_extensions import (
32+
CtypesArray,
8533
)
8634

8735

8836
# Specify the base name of the shared library to load
8937
_libllava_base_name = "llava"
38+
_libllava_override_path = os.environ.get("LLAVA_CPP_LIB")
39+
# Directory to search for the llava shared library: the bundled "lib" folder
# next to this file, unless the LLAVA_CPP_LIB environment override is set, in
# which case that value is used as the base path. (A bare pathlib.Path() is the
# current working directory and would silently discard the override.)
_libllava_base_path = (
    pathlib.Path(os.path.abspath(os.path.dirname(__file__))) / "lib"
    if _libllava_override_path is None
    else pathlib.Path(_libllava_override_path)
)
9040

9141
# Load the library
92-
_libllava = _load_shared_library(_libllava_base_name)
93-
94-
# ctypes helper
95-
96-
if TYPE_CHECKING:
97-
CtypesCData = TypeVar("CtypesCData", bound=ctypes._CData) # type: ignore
98-
99-
CtypesArray: TypeAlias = ctypes.Array[CtypesCData] # type: ignore
100-
101-
CtypesPointer: TypeAlias = ctypes._Pointer[CtypesCData] # type: ignore
102-
103-
CtypesVoidPointer: TypeAlias = ctypes.c_void_p
104-
105-
class CtypesRef(Generic[CtypesCData]):
106-
pass
107-
108-
CtypesPointerOrRef: TypeAlias = Union[
109-
CtypesPointer[CtypesCData], CtypesRef[CtypesCData]
110-
]
111-
112-
CtypesFuncPointer: TypeAlias = ctypes._FuncPointer # type: ignore
113-
114-
F = TypeVar("F", bound=Callable[..., Any])
115-
116-
117-
def ctypes_function_for_shared_library(lib: ctypes.CDLL):
118-
def ctypes_function(
119-
name: str, argtypes: List[Any], restype: Any, enabled: bool = True
120-
):
121-
def decorator(f: F) -> F:
122-
if enabled:
123-
func = getattr(lib, name)
124-
func.argtypes = argtypes
125-
func.restype = restype
126-
functools.wraps(f)(func)
127-
return func
128-
else:
129-
return f
130-
131-
return decorator
132-
133-
return ctypes_function
134-
42+
_libllava = load_shared_library(_libllava_base_name, _libllava_base_path)
13543

13644
ctypes_function = ctypes_function_for_shared_library(_libllava)
13745

@@ -247,3 +155,4 @@ def clip_model_load(
247155
@ctypes_function("clip_free", [clip_ctx_p_ctypes], None)
248156
def clip_free(ctx: clip_ctx_p, /):
249157
...
158+

0 commit comments

Comments
 (0)