kuzu

Kùzu Python API bindings.

This package provides a Python API for the Kùzu graph database management system.

To install the package, run:

python3 -m pip install kuzu

Example usage:

import kuzu

db = kuzu.Database("./test")
conn = kuzu.Connection(db)

# Define the schema
conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE City(name STRING, population INT64, PRIMARY KEY (name))")
conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")
conn.execute("CREATE REL TABLE LivesIn(FROM User TO City)")

# Load some data
conn.execute('COPY User FROM "user.csv"')
conn.execute('COPY City FROM "city.csv"')
conn.execute('COPY Follows FROM "follows.csv"')
conn.execute('COPY LivesIn FROM "lives-in.csv"')

# Query the data
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
while results.has_next():
    print(results.get_next())

The dataset used in this example can be found here: https://github.com/kuzudb/kuzu/tree/master/dataset/demo-db/csv.

 1"""
 2# Kùzu Python API bindings.
 3
 4This package provides a Python API for Kùzu graph database management system.
 5
 6To install the package, run:
 7```
 8python3 -m pip install kuzu
 9```
10
11Example usage:
12```python
13import kuzu
14
15db = kuzu.Database("./test")
16conn = kuzu.Connection(db)
17
18# Define the schema
19conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
20conn.execute("CREATE NODE TABLE City(name STRING, population INT64, PRIMARY KEY (name))")
21conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")
22conn.execute("CREATE REL TABLE LivesIn(FROM User TO City)")
23
24# Load some data
25conn.execute('COPY User FROM "user.csv"')
26conn.execute('COPY City FROM "city.csv"')
27conn.execute('COPY Follows FROM "follows.csv"')
28conn.execute('COPY LivesIn FROM "lives-in.csv"')
29
30# Query the data
31results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
32while results.has_next():
33    print(results.get_next())
34```
35
36The dataset used in this example can be found [here](https://github.com/kuzudb/kuzu/tree/master/dataset/demo-db/csv).
37
38"""
39
40from __future__ import annotations
41
42import os
43import sys
44
45# Set RTLD_GLOBAL and RTLD_LAZY flags on Linux to fix the issue with loading
46# extensions
47if sys.platform == "linux":
48    original_dlopen_flags = sys.getdlopenflags()
49    sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY)
50
51from .connection import Connection
52from .database import Database
53from .prepared_statement import PreparedStatement
54from .query_result import QueryResult
55from .types import Type
56
57
58def __getattr__(name: str) -> str | int:
59    if name in ("version", "__version__"):
60        return Database.get_version()
61    elif name == "storage_version":
62        return Database.get_storage_version()
63    else:
64        msg = f"module {__name__!r} has no attribute {name!r}"
65        raise AttributeError(msg)
66
67
68# Restore the original dlopen flags
69if sys.platform == "linux":
70    sys.setdlopenflags(original_dlopen_flags)
71
72__all__ = [
73    "Connection",
74    "Database",
75    "PreparedStatement",
76    "QueryResult",
77    "Type",
78    "__version__",
79    "storage_version",
80    "version",
81]
class Connection:
 23class Connection:
 24    """Connection to a database."""
 25
 26    def __init__(self, database: Database, num_threads: int = 0):
 27        """
 28        Initialise kuzu database connection.
 29
 30        Parameters
 31        ----------
 32        database : Database
 33            Database to connect to.
 34
 35        num_threads : int
 36            Maximum number of threads to use for executing queries.
 37
 38        """
 39        self._connection: Any = None  # (type: _kuzu.Connection from pybind11)
 40        self.database = database
 41        self.num_threads = num_threads
 42        self.is_closed = False
 43        self.init_connection()
 44
 45    def __getstate__(self) -> dict[str, Any]:
 46        state = {
 47            "database": self.database,
 48            "num_threads": self.num_threads,
 49            "_connection": None,
 50        }
 51        return state
 52
 53    def init_connection(self) -> None:
 54        """Establish a connection to the database, if not already initalised."""
 55        if self.is_closed:
 56            error_msg = "Connection is closed."
 57            raise RuntimeError(error_msg)
 58        self.database.init_database()
 59        if self._connection is None:
 60            self._connection = _kuzu.Connection(self.database._database, self.num_threads)  # type: ignore[union-attr]
 61
 62    def set_max_threads_for_exec(self, num_threads: int) -> None:
 63        """
 64        Set the maximum number of threads for executing queries.
 65
 66        Parameters
 67        ----------
 68        num_threads : int
 69            Maximum number of threads to use for executing queries.
 70
 71        """
 72        self.init_connection()
 73        self._connection.set_max_threads_for_exec(num_threads)
 74
 75    def close(self) -> None:
 76        """
 77        Close the connection.
 78
 79        Note: Call to this method is optional. The connection will be closed
 80        automatically when the object goes out of scope.
 81        """
 82        if self._connection is not None:
 83            self._connection.close()
 84        self._connection = None
 85        self.is_closed = True
 86
 87    def __enter__(self) -> Self:
 88        return self
 89
 90    def __exit__(
 91        self,
 92        exc_type: type[BaseException] | None,
 93        exc_value: BaseException | None,
 94        exc_traceback: TracebackType | None,
 95    ) -> None:
 96        self.close()
 97
 98    def execute(
 99        self,
100        query: str | PreparedStatement,
101        parameters: dict[str, Any] | None = None,
102    ) -> QueryResult | list[QueryResult]:
103        """
104        Execute a query.
105
106        Parameters
107        ----------
108        query : str | PreparedStatement
109            A prepared statement or a query string.
110            If a query string is given, a prepared statement will be created
111            automatically.
112
113        parameters : dict[str, Any]
114            Parameters for the query.
115
116        Returns
117        -------
118        QueryResult
119            Query result.
120
121        """
122        if parameters is None:
123            parameters = {}
124
125        self.init_connection()
126        if not isinstance(parameters, dict):
127            msg = f"Parameters must be a dict; found {type(parameters)}."
128            raise RuntimeError(msg)  # noqa: TRY004
129
130        if len(parameters) == 0:
131            _query_result = self._connection.query(query)
132        else:
133            prepared_statement = self.prepare(query) if isinstance(query, str) else query
134            _query_result = self._connection.execute(prepared_statement._prepared_statement, parameters)
135        if not _query_result.isSuccess():
136            raise RuntimeError(_query_result.getErrorMessage())
137        current_query_result = QueryResult(self, _query_result)
138        if not _query_result.hasNextQueryResult():
139            return current_query_result
140        all_query_results = [current_query_result]
141        while _query_result.hasNextQueryResult():
142            _query_result = _query_result.getNextQueryResult()
143            if not _query_result.isSuccess():
144                raise RuntimeError(_query_result.getErrorMessage())
145            all_query_results.append(QueryResult(self, _query_result))
146        return all_query_results
147
148    def prepare(self, query: str) -> PreparedStatement:
149        """
150        Create a prepared statement for a query.
151
152        Parameters
153        ----------
154        query : str
155            Query to prepare.
156
157        Returns
158        -------
159        PreparedStatement
160            Prepared statement.
161
162        """
163        return PreparedStatement(self, query)
164
165    def _get_node_property_names(self, table_name: str) -> dict[str, Any]:
166        LIST_START_SYMBOL = "["
167        LIST_END_SYMBOL = "]"
168        self.init_connection()
169        query_result = self.execute(f"CALL table_info('{table_name}') RETURN *;")
170        results = {}
171        while query_result.has_next():
172            row = query_result.get_next()
173            prop_name = row[1]
174            prop_type = row[2]
175            is_primary_key = row[4] is True
176            dimension = prop_type.count(LIST_START_SYMBOL)
177            splitted = prop_type.split(LIST_START_SYMBOL)
178            shape = []
179            for s in splitted:
180                if LIST_END_SYMBOL not in s:
181                    continue
182                s = s.split(LIST_END_SYMBOL)[0]
183                if s != "":
184                    shape.append(int(s))
185            prop_type = splitted[0]
186            results[prop_name] = {
187                "type": prop_type,
188                "dimension": dimension,
189                "is_primary_key": is_primary_key,
190            }
191            if len(shape) > 0:
192                results[prop_name]["shape"] = tuple(shape)
193        return results
194
195    def _get_node_table_names(self) -> list[Any]:
196        results = []
197        self.init_connection()
198        query_result = self.execute("CALL show_tables() RETURN *;")
199        while query_result.has_next():
200            row = query_result.get_next()
201            if row[2] == "NODE":
202                results.append(row[1])
203        return results
204
205    def _get_rel_table_names(self) -> list[dict[str, Any]]:
206        results = []
207        self.init_connection()
208        tables_result = self.execute("CALL show_tables() RETURN *;")
209        while tables_result.has_next():
210            row = tables_result.get_next()
211            if row[2] == "REL":
212                name = row[1]
213                connections_result = self.execute(f"CALL show_connection({name!r}) RETURN *;")
214                src_dst_row = connections_result.get_next()
215                src_node = src_dst_row[0]
216                dst_node = src_dst_row[1]
217                results.append({"name": name, "src": src_node, "dst": dst_node})
218        return results
219
220    def set_query_timeout(self, timeout_in_ms: int) -> None:
221        """
222        Set the query timeout value in ms for executing queries.
223
224        Parameters
225        ----------
226        timeout_in_ms : int
227            query timeout value in ms for executing queries.
228
229        """
230        self.init_connection()
231        self._connection.set_query_timeout(timeout_in_ms)
232
233    def create_function(
234        self,
235        name: str,
236        udf: Callable[[...], Any],
237        params_type: list[Type | str] | None = None,
238        return_type: Type | str = "",
239        *,
240        default_null_handling: bool = True,
241        catch_exceptions: bool = False,
242    ) -> None:
243        """
244        Sets a User Defined Function (UDF) to use in cypher queries.
245
246        Parameters
247        ----------
248        name: str
249            name of function
250
251        udf: Callable[[...], Any]
252            function to be executed
253
254        params_type: Optional[list[Type]]
255            list of Type enums to describe the input parameters
256
257        return_type: Optional[Type]
258            a Type enum to describe the returned value
259
260        default_null_handling: Optional[bool]
261            if true, when any parameter is null, the resulting value will be null
262
263        catch_exceptions: Optional[bool]
264            if true, when an exception is thrown from python, the function output will be null
265            Otherwise, the exception will be rethrown
266        """
267        if params_type is None:
268            params_type = []
269        parsed_params_type = [x if type(x) is str else x.value for x in params_type]
270        if type(return_type) is not str:
271            return_type = return_type.value
272
273        self._connection.create_function(
274            name=name,
275            udf=udf,
276            params_type=parsed_params_type,
277            return_value=return_type,
278            default_null=default_null_handling,
279            catch_exceptions=catch_exceptions,
280        )
281
282    def remove_function(self, name: str) -> None:
283        """
284        Removes a User Defined Function (UDF).
285
286        Parameters
287        ----------
288        name: str
289            name of function to be removed.
290        """
291        self._connection.remove_function(name)

Connection to a database.

Connection(database: Database, num_threads: int = 0)
26    def __init__(self, database: Database, num_threads: int = 0):
27        """
28        Initialise kuzu database connection.
29
30        Parameters
31        ----------
32        database : Database
33            Database to connect to.
34
35        num_threads : int
36            Maximum number of threads to use for executing queries.
37
38        """
39        self._connection: Any = None  # (type: _kuzu.Connection from pybind11)
40        self.database = database
41        self.num_threads = num_threads
42        self.is_closed = False
43        self.init_connection()

Initialise kuzu database connection.

Parameters
  • database (Database): Database to connect to.
  • num_threads (int): Maximum number of threads to use for executing queries.
database
num_threads
is_closed
def init_connection(self) -> None:
53    def init_connection(self) -> None:
54        """Establish a connection to the database, if not already initalised."""
55        if self.is_closed:
56            error_msg = "Connection is closed."
57            raise RuntimeError(error_msg)
58        self.database.init_database()
59        if self._connection is None:
60            self._connection = _kuzu.Connection(self.database._database, self.num_threads)  # type: ignore[union-attr]

Establish a connection to the database, if not already initialised.

def set_max_threads_for_exec(self, num_threads: int) -> None:
62    def set_max_threads_for_exec(self, num_threads: int) -> None:
63        """
64        Set the maximum number of threads for executing queries.
65
66        Parameters
67        ----------
68        num_threads : int
69            Maximum number of threads to use for executing queries.
70
71        """
72        self.init_connection()
73        self._connection.set_max_threads_for_exec(num_threads)

Set the maximum number of threads for executing queries.

Parameters
  • num_threads (int): Maximum number of threads to use for executing queries.
def close(self) -> None:
75    def close(self) -> None:
76        """
77        Close the connection.
78
79        Note: Call to this method is optional. The connection will be closed
80        automatically when the object goes out of scope.
81        """
82        if self._connection is not None:
83            self._connection.close()
84        self._connection = None
85        self.is_closed = True

Close the connection.

Note: Call to this method is optional. The connection will be closed automatically when the object goes out of scope.

def execute( self, query: str | PreparedStatement, parameters: dict[str, typing.Any] | None = None) -> QueryResult | list[QueryResult]:
 98    def execute(
 99        self,
100        query: str | PreparedStatement,
101        parameters: dict[str, Any] | None = None,
102    ) -> QueryResult | list[QueryResult]:
103        """
104        Execute a query.
105
106        Parameters
107        ----------
108        query : str | PreparedStatement
109            A prepared statement or a query string.
110            If a query string is given, a prepared statement will be created
111            automatically.
112
113        parameters : dict[str, Any]
114            Parameters for the query.
115
116        Returns
117        -------
118        QueryResult
119            Query result.
120
121        """
122        if parameters is None:
123            parameters = {}
124
125        self.init_connection()
126        if not isinstance(parameters, dict):
127            msg = f"Parameters must be a dict; found {type(parameters)}."
128            raise RuntimeError(msg)  # noqa: TRY004
129
130        if len(parameters) == 0:
131            _query_result = self._connection.query(query)
132        else:
133            prepared_statement = self.prepare(query) if isinstance(query, str) else query
134            _query_result = self._connection.execute(prepared_statement._prepared_statement, parameters)
135        if not _query_result.isSuccess():
136            raise RuntimeError(_query_result.getErrorMessage())
137        current_query_result = QueryResult(self, _query_result)
138        if not _query_result.hasNextQueryResult():
139            return current_query_result
140        all_query_results = [current_query_result]
141        while _query_result.hasNextQueryResult():
142            _query_result = _query_result.getNextQueryResult()
143            if not _query_result.isSuccess():
144                raise RuntimeError(_query_result.getErrorMessage())
145            all_query_results.append(QueryResult(self, _query_result))
146        return all_query_results

Execute a query.

Parameters
  • query (str | PreparedStatement): A prepared statement or a query string. If a query string is given, a prepared statement will be created automatically.
  • parameters (dict[str, Any]): Parameters for the query.
Returns
  • QueryResult | list[QueryResult]: Query result, or a list of query results if more than one result is produced.
def prepare(self, query: str) -> PreparedStatement:
148    def prepare(self, query: str) -> PreparedStatement:
149        """
150        Create a prepared statement for a query.
151
152        Parameters
153        ----------
154        query : str
155            Query to prepare.
156
157        Returns
158        -------
159        PreparedStatement
160            Prepared statement.
161
162        """
163        return PreparedStatement(self, query)

Create a prepared statement for a query.

Parameters
  • query (str): Query to prepare.
Returns
  • PreparedStatement: Prepared statement.
def set_query_timeout(self, timeout_in_ms: int) -> None:
220    def set_query_timeout(self, timeout_in_ms: int) -> None:
221        """
222        Set the query timeout value in ms for executing queries.
223
224        Parameters
225        ----------
226        timeout_in_ms : int
227            query timeout value in ms for executing queries.
228
229        """
230        self.init_connection()
231        self._connection.set_query_timeout(timeout_in_ms)

Set the query timeout value in ms for executing queries.

Parameters
  • timeout_in_ms (int): query timeout value in ms for executing queries.
def create_function( self, name: str, udf: Callable[..., Any], params_type: list[Type | str] | None = None, return_type: Type | str = '', *, default_null_handling: bool = True, catch_exceptions: bool = False) -> None:
233    def create_function(
234        self,
235        name: str,
236        udf: Callable[[...], Any],
237        params_type: list[Type | str] | None = None,
238        return_type: Type | str = "",
239        *,
240        default_null_handling: bool = True,
241        catch_exceptions: bool = False,
242    ) -> None:
243        """
244        Sets a User Defined Function (UDF) to use in cypher queries.
245
246        Parameters
247        ----------
248        name: str
249            name of function
250
251        udf: Callable[[...], Any]
252            function to be executed
253
254        params_type: Optional[list[Type]]
255            list of Type enums to describe the input parameters
256
257        return_type: Optional[Type]
258            a Type enum to describe the returned value
259
260        default_null_handling: Optional[bool]
261            if true, when any parameter is null, the resulting value will be null
262
263        catch_exceptions: Optional[bool]
264            if true, when an exception is thrown from python, the function output will be null
265            Otherwise, the exception will be rethrown
266        """
267        if params_type is None:
268            params_type = []
269        parsed_params_type = [x if type(x) is str else x.value for x in params_type]
270        if type(return_type) is not str:
271            return_type = return_type.value
272
273        self._connection.create_function(
274            name=name,
275            udf=udf,
276            params_type=parsed_params_type,
277            return_value=return_type,
278            default_null=default_null_handling,
279            catch_exceptions=catch_exceptions,
280        )

Registers a User Defined Function (UDF) for use in Cypher queries.

Parameters
  • name (str): name of function
  • udf (Callable[..., Any]): function to be executed
  • params_type (Optional[list[Type]]): list of Type enums to describe the input parameters
  • return_type (Optional[Type]): a Type enum to describe the returned value
  • default_null_handling (Optional[bool]): if true, when any parameter is null, the resulting value will be null
  • catch_exceptions (Optional[bool]): if true, when an exception is thrown from python, the function output will be null; otherwise, the exception will be rethrown
def remove_function(self, name: str) -> None:
282    def remove_function(self, name: str) -> None:
283        """
284        Removes a User Defined Function (UDF).
285
286        Parameters
287        ----------
288        name: str
289            name of function to be removed.
290        """
291        self._connection.remove_function(name)

Removes a User Defined Function (UDF).

Parameters
  • name (str): name of function to be removed.
class Database:
 26class Database:
 27    """Kùzu database instance."""
 28
 29    def __init__(
 30        self,
 31        database_path: str | Path | None = None,
 32        *,
 33        buffer_pool_size: int = 0,
 34        max_num_threads: int = 0,
 35        compression: bool = True,
 36        lazy_init: bool = False,
 37        read_only: bool = False,
 38        max_db_size: int = (1 << 43),
 39    ):
 40        """
 41        Parameters
 42        ----------
 43        database_path : str, Path
 44            The path to database files. If the path is not specified, or empty, or equal to `:memory:`, the database
 45            will be created in memory.
 46
 47        buffer_pool_size : int
 48            The maximum size of buffer pool in bytes. Defaults to ~80% of system memory.
 49
 50        max_num_threads : int
 51            The maximum number of threads to use for executing queries.
 52
 53        compression : bool
 54            Enable database compression.
 55
 56        lazy_init : bool
 57            If True, the database will not be initialized until the first query.
 58            This is useful when the database is not used in the main thread or
 59            when the main process is forked.
 60            Default to False.
 61
 62        read_only : bool
 63            If true, the database is opened read-only. No write transactions is
 64            allowed on the `Database` object. Multiple read-only `Database`
 65            objects can be created with the same database path. However, there
 66            cannot be multiple `Database` objects created with the same
 67            database path.
 68            Default to False.
 69
 70        max_db_size : int
 71            The maximum size of the database in bytes. Note that this is introduced
 72            temporarily for now to get around with the default 8TB mmap address
 73             space limit some environment. This will be removed once we implemente
 74             a better solution later. The value is default to 1 << 43 (8TB) under 64-bit
 75             environment and 1GB under 32-bit one.
 76
 77        """
 78        if database_path is None:
 79            database_path = ":memory:"
 80        if isinstance(database_path, Path):
 81            database_path = str(database_path)
 82
 83        self.database_path = database_path
 84        self.buffer_pool_size = buffer_pool_size
 85        self.max_num_threads = max_num_threads
 86        self.compression = compression
 87        self.read_only = read_only
 88        self.max_db_size = max_db_size
 89        self.is_closed = False
 90
 91        self._database: Any = None  # (type: _kuzu.Database from pybind11)
 92        if not lazy_init:
 93            self.init_database()
 94
 95    def __enter__(self) -> Self:
 96        return self
 97
 98    def __exit__(
 99        self,
100        exc_type: type[BaseException] | None,
101        exc_value: BaseException | None,
102        exc_traceback: TracebackType | None,
103    ) -> None:
104        self.close()
105
106    @staticmethod
107    def get_version() -> str:
108        """
109        Get the version of the database.
110
111        Returns
112        -------
113        str
114            The version of the database.
115        """
116        return _kuzu.Database.get_version()  # type: ignore[union-attr]
117
118    @staticmethod
119    def get_storage_version() -> int:
120        """
121        Get the storage version of the database.
122
123        Returns
124        -------
125        int
126            The storage version of the database.
127        """
128        return _kuzu.Database.get_storage_version()  # type: ignore[union-attr]
129
130    def __getstate__(self) -> dict[str, Any]:
131        state = {
132            "database_path": self.database_path,
133            "buffer_pool_size": self.buffer_pool_size,
134            "compression": self.compression,
135            "read_only": self.read_only,
136            "_database": None,
137        }
138        return state
139
140    def init_database(self) -> None:
141        """Initialize the database."""
142        self.check_for_database_close()
143        if self._database is None:
144            self._database = _kuzu.Database(  # type: ignore[union-attr]
145                self.database_path,
146                self.buffer_pool_size,
147                self.max_num_threads,
148                self.compression,
149                self.read_only,
150                self.max_db_size,
151            )
152
153    def get_torch_geometric_remote_backend(
154        self, num_threads: int | None = None
155    ) -> tuple[KuzuFeatureStore, KuzuGraphStore]:
156        """
157        Use the database as the remote backend for torch_geometric.
158
159        For the interface of the remote backend, please refer to
160        https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html.
161        The current implementation is read-only and does not support edge
162        features. The IDs of the nodes are based on the internal IDs (i.e., node
163        offsets). For the remote node IDs to be consistent with the positions in
164        the output tensors, please ensure that no deletion has been performed
165        on the node tables.
166
167        The remote backend can also be plugged into the data loader of
168        torch_geometric, which is useful for mini-batch training. For example:
169
170        ```python
171            loader_kuzu = NeighborLoader(
172                data=(feature_store, graph_store),
173                num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
174                batch_size=LOADER_BATCH_SIZE,
175                input_nodes=('paper', input_nodes),
176                num_workers=4,
177                filter_per_worker=False,
178            )
179        ```
180
181        Please note that the database instance is not fork-safe, so if more than
182        one worker is used, `filter_per_worker` must be set to False.
183
184        Parameters
185        ----------
186        num_threads : int
187            Number of threads to use for data loading. Default to None, which
188            means using the number of CPU cores.
189
190        Returns
191        -------
192        feature_store : KuzuFeatureStore
193            Feature store compatible with torch_geometric.
194        graph_store : KuzuGraphStore
195            Graph store compatible with torch_geometric.
196        """
197        self.check_for_database_close()
198        from .torch_geometric_feature_store import KuzuFeatureStore
199        from .torch_geometric_graph_store import KuzuGraphStore
200
201        return (
202            KuzuFeatureStore(self, num_threads),
203            KuzuGraphStore(self, num_threads),
204        )
205
206    def _scan_node_table(
207        self,
208        table_name: str,
209        prop_name: str,
210        prop_type: str,
211        dim: int,
212        indices: IndexType,
213        num_threads: int,
214    ) -> NDArray[Any]:
215        self.check_for_database_close()
216        import numpy as np
217
218        """
219        Scan a node table from storage directly, bypassing query engine.
220        Used internally by torch_geometric remote backend only.
221        """
222        self.init_database()
223        indices_cast = np.array(indices, dtype=np.uint64)
224        result = None
225
226        if prop_type == Type.INT64.value:
227            result = np.empty(len(indices) * dim, dtype=np.int64)
228            self._database.scan_node_table_as_int64(table_name, prop_name, indices_cast, result, num_threads)
229        elif prop_type == Type.INT32.value:
230            result = np.empty(len(indices) * dim, dtype=np.int32)
231            self._database.scan_node_table_as_int32(table_name, prop_name, indices_cast, result, num_threads)
232        elif prop_type == Type.INT16.value:
233            result = np.empty(len(indices) * dim, dtype=np.int16)
234            self._database.scan_node_table_as_int16(table_name, prop_name, indices_cast, result, num_threads)
235        elif prop_type == Type.DOUBLE.value:
236            result = np.empty(len(indices) * dim, dtype=np.float64)
237            self._database.scan_node_table_as_double(table_name, prop_name, indices_cast, result, num_threads)
238        elif prop_type == Type.FLOAT.value:
239            result = np.empty(len(indices) * dim, dtype=np.float32)
240            self._database.scan_node_table_as_float(table_name, prop_name, indices_cast, result, num_threads)
241
242        if result is not None:
243            return result
244
245        msg = f"Unsupported property type: {prop_type}"
246        raise ValueError(msg)
247
248    def close(self) -> None:
249        """
250        Close the database. Once the database is closed, the lock on the database
251        files is released and the database can be opened in another process.
252
253        Note: Call to this method is not required. The Python garbage collector
254        will automatically close the database when no references to the database
255        object exist. It is recommended not to call this method explicitly. If you
256        decide to manually close the database, make sure that all the QueryResult
257        and Connection objects are closed before calling this method.
258        """
259        if self.is_closed:
260            return
261        self.is_closed = True
262        if self._database is not None:
263            self._database.close()
264            self._database: Any = None  # (type: _kuzu.Database from pybind11)
265
266    def check_for_database_close(self) -> None:
267        """
268        Check if the database is closed and raise an exception if it is.
269
270        Raises
271        ------
272        Exception
273            If the database is closed.
274
275        """
276        if not self.is_closed:
277            return
278        msg = "Database is closed"
279        raise RuntimeError(msg)

Kùzu database instance.

Database( database_path: str | pathlib.Path | None = None, *, buffer_pool_size: int = 0, max_num_threads: int = 0, compression: bool = True, lazy_init: bool = False, read_only: bool = False, max_db_size: int = 8796093022208)
29    def __init__(
30        self,
31        database_path: str | Path | None = None,
32        *,
33        buffer_pool_size: int = 0,
34        max_num_threads: int = 0,
35        compression: bool = True,
36        lazy_init: bool = False,
37        read_only: bool = False,
38        max_db_size: int = (1 << 43),
39    ):
40        """
41        Parameters
42        ----------
43        database_path : str, Path
44            The path to database files. If the path is not specified, or empty, or equal to `:memory:`, the database
45            will be created in memory.
46
47        buffer_pool_size : int
48            The maximum size of buffer pool in bytes. Defaults to ~80% of system memory.
49
50        max_num_threads : int
51            The maximum number of threads to use for executing queries.
52
53        compression : bool
54            Enable database compression.
55
56        lazy_init : bool
57            If True, the database will not be initialized until the first query.
58            This is useful when the database is not used in the main thread or
59            when the main process is forked.
60            Default to False.
61
62        read_only : bool
63            If true, the database is opened read-only. No write transactions is
64            allowed on the `Database` object. Multiple read-only `Database`
65            objects can be created with the same database path. However, there
66            cannot be multiple `Database` objects created with the same
67            database path.
68            Default to False.
69
70        max_db_size : int
71            The maximum size of the database in bytes. Note that this is introduced
72            temporarily for now to get around with the default 8TB mmap address
73             space limit some environment. This will be removed once we implemente
74             a better solution later. The value is default to 1 << 43 (8TB) under 64-bit
75             environment and 1GB under 32-bit one.
76
77        """
78        if database_path is None:
79            database_path = ":memory:"
80        if isinstance(database_path, Path):
81            database_path = str(database_path)
82
83        self.database_path = database_path
84        self.buffer_pool_size = buffer_pool_size
85        self.max_num_threads = max_num_threads
86        self.compression = compression
87        self.read_only = read_only
88        self.max_db_size = max_db_size
89        self.is_closed = False
90
91        self._database: Any = None  # (type: _kuzu.Database from pybind11)
92        if not lazy_init:
93            self.init_database()
Parameters
  • database_path (str, Path): The path to the database files. If the path is not specified, is empty, or is equal to :memory:, the database will be created in memory.
  • buffer_pool_size (int): The maximum size of the buffer pool in bytes. The default value of 0 means ~80% of system memory is used.
  • max_num_threads (int): The maximum number of threads to use for executing queries.
  • compression (bool): Enable database compression.
  • lazy_init (bool): If True, the database will not be initialized until the first query. This is useful when the database is not used in the main thread or when the main process is forked. Defaults to False.
  • read_only (bool): If True, the database is opened in read-only mode, and no write transactions are allowed on the Database object. Multiple read-only Database objects can be created with the same database path; however, there cannot be multiple non-read-only Database objects created with the same database path. Defaults to False.
  • max_db_size (int): The maximum size of the database in bytes. Note that this is a temporary option to work around the default 8TB mmap address space limit in some environments. It will be removed once a better solution is implemented. The value defaults to 1 << 43 (8TB) in 64-bit environments and 1GB in 32-bit environments.
database_path
buffer_pool_size
max_num_threads
compression
read_only
max_db_size
is_closed
@staticmethod
def get_version() -> str:
106    @staticmethod
107    def get_version() -> str:
108        """
109        Get the version of the database.
110
111        Returns
112        -------
113        str
114            The version of the database.
115        """
116        return _kuzu.Database.get_version()  # type: ignore[union-attr]

Get the version of the database.

Returns
  • str: The version of the database.
@staticmethod
def get_storage_version() -> int:
118    @staticmethod
119    def get_storage_version() -> int:
120        """
121        Get the storage version of the database.
122
123        Returns
124        -------
125        int
126            The storage version of the database.
127        """
128        return _kuzu.Database.get_storage_version()  # type: ignore[union-attr]

Get the storage version of the database.

Returns
  • int: The storage version of the database.
def init_database(self) -> None:
140    def init_database(self) -> None:
141        """Initialize the database."""
142        self.check_for_database_close()
143        if self._database is None:
144            self._database = _kuzu.Database(  # type: ignore[union-attr]
145                self.database_path,
146                self.buffer_pool_size,
147                self.max_num_threads,
148                self.compression,
149                self.read_only,
150                self.max_db_size,
151            )

Initialize the database.

def get_torch_geometric_remote_backend( self, num_threads: int | None = None) -> tuple[kuzu.torch_geometric_feature_store.KuzuFeatureStore, kuzu.torch_geometric_graph_store.KuzuGraphStore]:
153    def get_torch_geometric_remote_backend(
154        self, num_threads: int | None = None
155    ) -> tuple[KuzuFeatureStore, KuzuGraphStore]:
156        """
157        Use the database as the remote backend for torch_geometric.
158
159        For the interface of the remote backend, please refer to
160        https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html.
161        The current implementation is read-only and does not support edge
162        features. The IDs of the nodes are based on the internal IDs (i.e., node
163        offsets). For the remote node IDs to be consistent with the positions in
164        the output tensors, please ensure that no deletion has been performed
165        on the node tables.
166
167        The remote backend can also be plugged into the data loader of
168        torch_geometric, which is useful for mini-batch training. For example:
169
170        ```python
171            loader_kuzu = NeighborLoader(
172                data=(feature_store, graph_store),
173                num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
174                batch_size=LOADER_BATCH_SIZE,
175                input_nodes=('paper', input_nodes),
176                num_workers=4,
177                filter_per_worker=False,
178            )
179        ```
180
181        Please note that the database instance is not fork-safe, so if more than
182        one worker is used, `filter_per_worker` must be set to False.
183
184        Parameters
185        ----------
186        num_threads : int
187            Number of threads to use for data loading. Default to None, which
188            means using the number of CPU cores.
189
190        Returns
191        -------
192        feature_store : KuzuFeatureStore
193            Feature store compatible with torch_geometric.
194        graph_store : KuzuGraphStore
195            Graph store compatible with torch_geometric.
196        """
197        self.check_for_database_close()
198        from .torch_geometric_feature_store import KuzuFeatureStore
199        from .torch_geometric_graph_store import KuzuGraphStore
200
201        return (
202            KuzuFeatureStore(self, num_threads),
203            KuzuGraphStore(self, num_threads),
204        )

Use the database as the remote backend for torch_geometric.

For the interface of the remote backend, please refer to https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html. The current implementation is read-only and does not support edge features. The IDs of the nodes are based on the internal IDs (i.e., node offsets). For the remote node IDs to be consistent with the positions in the output tensors, please ensure that no deletion has been performed on the node tables.

The remote backend can also be plugged into the data loader of torch_geometric, which is useful for mini-batch training. For example:

    loader_kuzu = NeighborLoader(
        data=(feature_store, graph_store),
        num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
        batch_size=LOADER_BATCH_SIZE,
        input_nodes=('paper', input_nodes),
        num_workers=4,
        filter_per_worker=False,
    )

Please note that the database instance is not fork-safe, so if more than one worker is used, filter_per_worker must be set to False.

Parameters
  • num_threads (int): Number of threads to use for data loading. Defaults to None, which means using the number of CPU cores.
Returns
  • feature_store (KuzuFeatureStore): Feature store compatible with torch_geometric.
  • graph_store (KuzuGraphStore): Graph store compatible with torch_geometric.
def close(self) -> None:
248    def close(self) -> None:
249        """
250        Close the database. Once the database is closed, the lock on the database
251        files is released and the database can be opened in another process.
252
253        Note: Call to this method is not required. The Python garbage collector
254        will automatically close the database when no references to the database
255        object exist. It is recommended not to call this method explicitly. If you
256        decide to manually close the database, make sure that all the QueryResult
257        and Connection objects are closed before calling this method.
258        """
259        if self.is_closed:
260            return
261        self.is_closed = True
262        if self._database is not None:
263            self._database.close()
264            self._database: Any = None  # (type: _kuzu.Database from pybind11)

Close the database. Once the database is closed, the lock on the database files is released and the database can be opened in another process.

Note: Calling this method is not required. The Python garbage collector will automatically close the database when no references to the database object exist. It is recommended not to call this method explicitly. If you decide to close the database manually, make sure that all the QueryResult and Connection objects are closed before calling this method.

def check_for_database_close(self) -> None:
266    def check_for_database_close(self) -> None:
267        """
268        Check if the database is closed and raise an exception if it is.
269
270        Raises
271        ------
272        Exception
273            If the database is closed.
274
275        """
276        if not self.is_closed:
277            return
278        msg = "Database is closed"
279        raise RuntimeError(msg)

Check if the database is closed and raise an exception if it is.

Raises
  • RuntimeError: If the database is closed.
class PreparedStatement:
10class PreparedStatement:
11    """
12    A prepared statement is a parameterized query which can avoid planning the
13    same query for repeated execution.
14    """
15
16    def __init__(self, connection: Connection, query: str):
17        """
18        Parameters
19        ----------
20        connection : Connection
21            Connection to a database.
22        query : str
23            Query to prepare.
24        """
25        self._prepared_statement = connection._connection.prepare(query)
26
27    def is_success(self) -> bool:
28        """
29        Check if the prepared statement is successfully prepared.
30
31        Returns
32        -------
33        bool
34            True if the prepared statement is successfully prepared.
35        """
36        return self._prepared_statement.is_success()
37
38    def get_error_message(self) -> str:
39        """
40        Get the error message if the query is not prepared successfully.
41
42        Returns
43        -------
44        str
45            Error message.
46        """
47        return self._prepared_statement.get_error_message()

A prepared statement is a parameterized query that avoids re-planning the same query on repeated execution.

PreparedStatement(connection: Connection, query: str)
16    def __init__(self, connection: Connection, query: str):
17        """
18        Parameters
19        ----------
20        connection : Connection
21            Connection to a database.
22        query : str
23            Query to prepare.
24        """
25        self._prepared_statement = connection._connection.prepare(query)
Parameters
  • connection (Connection): Connection to a database.
  • query (str): Query to prepare.
def is_success(self) -> bool:
27    def is_success(self) -> bool:
28        """
29        Check if the prepared statement is successfully prepared.
30
31        Returns
32        -------
33        bool
34            True if the prepared statement is successfully prepared.
35        """
36        return self._prepared_statement.is_success()

Check if the prepared statement is successfully prepared.

Returns
  • bool: True if the prepared statement is successfully prepared.
def get_error_message(self) -> str:
38    def get_error_message(self) -> str:
39        """
40        Get the error message if the query is not prepared successfully.
41
42        Returns
43        -------
44        str
45            Error message.
46        """
47        return self._prepared_statement.get_error_message()

Get the error message if the query is not prepared successfully.

Returns
  • str: Error message.
class QueryResult:
 27class QueryResult:
 28    """QueryResult stores the result of a query execution."""
 29
 30    def __init__(self, connection: _kuzu.Connection, query_result: _kuzu.QueryResult):  # type: ignore[name-defined]
 31        """
 32        Parameters
 33        ----------
 34        connection : _kuzu.Connection
 35            The underlying C++ connection object from pybind11.
 36
 37        query_result : _kuzu.QueryResult
 38            The underlying C++ query result object from pybind11.
 39
 40        """
 41        self.connection = connection
 42        self._query_result = query_result
 43        self.is_closed = False
 44
 45    def __enter__(self) -> Self:
 46        return self
 47
 48    def __exit__(
 49        self,
 50        exc_type: type[BaseException] | None,
 51        exc_value: BaseException | None,
 52        exc_traceback: TracebackType | None,
 53    ) -> None:
 54        self.close()
 55
 56    def __del__(self) -> None:
 57        self.close()
 58
 59    def check_for_query_result_close(self) -> None:
 60        """
 61        Check if the query result is closed and raise an exception if it is.
 62
 63        Raises
 64        ------
 65        Exception
 66            If the query result is closed.
 67
 68        """
 69        if self.is_closed:
 70            msg = "Query result is closed"
 71            raise RuntimeError(msg)
 72
 73    def has_next(self) -> bool:
 74        """
 75        Check if there are more rows in the query result.
 76
 77        Returns
 78        -------
 79        bool
 80            True if there are more rows in the query result, False otherwise.
 81
 82        """
 83        self.check_for_query_result_close()
 84        return self._query_result.hasNext()
 85
 86    def get_next(self) -> list[Any]:
 87        """
 88        Get the next row in the query result.
 89
 90        Returns
 91        -------
 92        list
 93            Next row in the query result.
 94
 95        """
 96        self.check_for_query_result_close()
 97        return self._query_result.getNext()
 98
 99    def close(self) -> None:
100        """Close the query result."""
101        if not self.is_closed:
102            # Allows the connection to be garbage collected if the query result
103            # is closed manually by the user.
104            self._query_result.close()
105            self.connection = None
106            self.is_closed = True
107
108    def get_as_df(self) -> pd.DataFrame:
109        """
110        Get the query result as a Pandas DataFrame.
111
112        See Also
113        --------
114        get_as_pl : Get the query result as a Polars DataFrame.
115        get_as_arrow : Get the query result as a PyArrow Table.
116
117        Returns
118        -------
119        pandas.DataFrame
120            Query result as a Pandas DataFrame.
121
122        """
123        self.check_for_query_result_close()
124
125        return self._query_result.getAsDF()
126
127    def get_as_pl(self) -> pl.DataFrame:
128        """
129        Get the query result as a Polars DataFrame.
130
131        See Also
132        --------
133        get_as_df : Get the query result as a Pandas DataFrame.
134        get_as_arrow : Get the query result as a PyArrow Table.
135
136        Returns
137        -------
138        polars.DataFrame
139            Query result as a Polars DataFrame.
140        """
141        import polars as pl
142
143        self.check_for_query_result_close()
144
145        # note: polars should always export just a single chunk,
146        # (eg: "-1") otherwise it will just need to rechunk anyway
147        return pl.from_arrow(  # type: ignore[return-value]
148            data=self.get_as_arrow(chunk_size=-1),
149        )
150
151    def get_as_arrow(self, chunk_size: int | None = None) -> pa.Table:
152        """
153        Get the query result as a PyArrow Table.
154
155        Parameters
156        ----------
157        chunk_size : Number of rows to include in each chunk.
158            None
159                The chunk size is adaptive and depends on the number of columns in the query result.
160            -1 or 0
161                The entire result is returned as a single chunk.
162            > 0
163                The chunk size is the number of rows specified.
164
165        See Also
166        --------
167        get_as_pl : Get the query result as a Polars DataFrame.
168        get_as_df : Get the query result as a Pandas DataFrame.
169
170        Returns
171        -------
172        pyarrow.Table
173            Query result as a PyArrow Table.
174        """
175        self.check_for_query_result_close()
176
177        if chunk_size is None:
178            # Adaptive; target 10m total elements in each chunk.
179            # (eg: if we had 10 cols, this would result in a 1m row chunk_size).
180            target_n_elems = 10_000_000
181            chunk_size = max(target_n_elems // len(self.get_column_names()), 10)
182        elif chunk_size <= 0:
183            # No chunking: return the entire result as a single chunk
184            chunk_size = self.get_num_tuples()
185
186        return self._query_result.getAsArrow(chunk_size)
187
188    def get_column_data_types(self) -> list[str]:
189        """
190        Get the data types of the columns in the query result.
191
192        Returns
193        -------
194        list
195            Data types of the columns in the query result.
196
197        """
198        self.check_for_query_result_close()
199        return self._query_result.getColumnDataTypes()
200
201    def get_column_names(self) -> list[str]:
202        """
203        Get the names of the columns in the query result.
204
205        Returns
206        -------
207        list
208            Names of the columns in the query result.
209
210        """
211        self.check_for_query_result_close()
212        return self._query_result.getColumnNames()
213
214    def get_schema(self) -> dict[str, str]:
215        """
216        Get the column schema of the query result.
217
218        Returns
219        -------
220        dict
221            Schema of the query result.
222
223        """
224        self.check_for_query_result_close()
225        return dict(
226            zip(
227                self._query_result.getColumnNames(),
228                self._query_result.getColumnDataTypes(),
229            )
230        )
231
232    def reset_iterator(self) -> None:
233        """Reset the iterator of the query result."""
234        self.check_for_query_result_close()
235        self._query_result.resetIterator()
236
237    def get_as_networkx(
238        self,
239        directed: bool = True,  # noqa: FBT001
240    ) -> nx.MultiGraph | nx.MultiDiGraph:
241        """
242        Convert the nodes and rels in query result into a NetworkX directed or undirected graph
243        with the following rules:
244        Columns with data type other than node or rel will be ignored.
245        Duplicated nodes and rels will be converted only once.
246
247        Parameters
248        ----------
249        directed : bool
250            Whether the graph should be directed. Defaults to True.
251
252        Returns
253        -------
254        networkx.MultiDiGraph or networkx.MultiGraph
255            Query result as a NetworkX graph.
256
257        """
258        self.check_for_query_result_close()
259        import networkx as nx
260
261        nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph()
262        properties_to_extract = self._get_properties_to_extract()
263
264        self.reset_iterator()
265
266        nodes = {}
267        rels = {}
268        table_to_label_dict = {}
269        table_primary_key_dict = {}
270
271        def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
272            node_label = node["_label"]
273            return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"
274
275        def encode_rel_id(rel: dict[str, Any]) -> tuple[int, int]:
276            return rel["_id"]["table"], rel["_id"]["offset"]
277
278        # De-duplicate nodes and rels
279        while self.has_next():
280            row = self.get_next()
281            for i in properties_to_extract:
282                # Skip empty nodes and rels, which may be returned by
283                # OPTIONAL MATCH
284                if row[i] is None or row[i] == {}:
285                    continue
286                column_type, _ = properties_to_extract[i]
287                if column_type == Type.NODE.value:
288                    _id = row[i]["_id"]
289                    nodes[(_id["table"], _id["offset"])] = row[i]
290                    table_to_label_dict[_id["table"]] = row[i]["_label"]
291
292                elif column_type == Type.REL.value:
293                    _src = row[i]["_src"]
294                    _dst = row[i]["_dst"]
295                    rels[encode_rel_id(row[i])] = row[i]
296
297                elif column_type == Type.RECURSIVE_REL.value:
298                    for node in row[i]["_nodes"]:
299                        _id = node["_id"]
300                        nodes[(_id["table"], _id["offset"])] = node
301                        table_to_label_dict[_id["table"]] = node["_label"]
302                    for rel in row[i]["_rels"]:
303                        for key in list(rel.keys()):
304                            if rel[key] is None:
305                                del rel[key]
306                        _src = rel["_src"]
307                        _dst = rel["_dst"]
308                        rels[encode_rel_id(rel)] = rel
309
310        # Add nodes
311        for node in nodes.values():
312            _id = node["_id"]
313            node_id = node["_label"] + "_" + str(_id["offset"])
314            if node["_label"] not in table_primary_key_dict:
315                props = self.connection._get_node_property_names(node["_label"])
316                for prop_name in props:
317                    if props[prop_name]["is_primary_key"]:
318                        table_primary_key_dict[node["_label"]] = prop_name
319                        break
320            node_id = encode_node_id(node, table_primary_key_dict)
321            node[node["_label"]] = True
322            nx_graph.add_node(node_id, **node)
323
324        # Add rels
325        for rel in rels.values():
326            _src = rel["_src"]
327            _dst = rel["_dst"]
328            src_node = nodes[(_src["table"], _src["offset"])]
329            dst_node = nodes[(_dst["table"], _dst["offset"])]
330            src_id = encode_node_id(src_node, table_primary_key_dict)
331            dst_id = encode_node_id(dst_node, table_primary_key_dict)
332            nx_graph.add_edge(src_id, dst_id, **rel)
333        return nx_graph
334
335    def _get_properties_to_extract(self) -> dict[int, tuple[str, str]]:
336        column_names = self.get_column_names()
337        column_types = self.get_column_data_types()
338        properties_to_extract = {}
339
340        # Iterate over columns and extract nodes and rels, ignoring other columns
341        for i in range(len(column_names)):
342            column_name = column_names[i]
343            column_type = column_types[i]
344            if column_type in [
345                Type.NODE.value,
346                Type.REL.value,
347                Type.RECURSIVE_REL.value,
348            ]:
349                properties_to_extract[i] = (column_type, column_name)
350        return properties_to_extract
351
352    def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]:  # type: ignore[type-arg]
353        """
354        Converts the nodes and rels in query result into a PyTorch Geometric graph representation
355        torch_geometric.data.Data or torch_geometric.data.HeteroData.
356
357        For node conversion, numerical and boolean properties are directly converted into tensor and
358        stored in Data/HeteroData. For properties cannot be converted into tensor automatically
359        (please refer to the notes below for more detail), they are returned as unconverted_properties.
360
361        For rel conversion, rel is converted into edge_index tensor director. Edge properties are returned
362        as edge_properties.
363
364        Node properties that cannot be converted into tensor automatically:
365        - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted
366          automatically.
367        - If a node property contains a null value, it cannot be converted automatically.
368        - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be
369          converted automatically.
370        - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length
371          is 6 for one node but 5 for another node), it cannot be converted automatically.
372
373        Additional conversion rules:
374        - Columns with data type other than node or rel will be ignored.
375        - Duplicated nodes and rels will be converted only once.
376
377        Returns
378        -------
379        torch_geometric.data.Data or torch_geometric.data.HeteroData
380            Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties
381            and edge_index tensor.
382
383        dict
384            A dictionary that maps the positional offset of each node in Data/HeteroData to its primary
385            key in the database.
386
387        dict
388            A dictionary contains node properties that cannot be converted into tensor automatically. The
389            order of values for each property is aligned with nodes in Data/HeteroData.
390
391        dict
392            A dictionary contains edge properties. The order of values for each property is aligned with
393            edge_index in Data/HeteroData.
394        """
395        self.check_for_query_result_close()
396        # Despite we are not using torch_geometric in this file, we need to
397        # import it here to throw an error early if the user does not have
398        # torch_geometric or torch installed.
399
400        converter = TorchGeometricResultConverter(self)
401        return converter.get_as_torch_geometric()
402
403    def get_execution_time(self) -> int:
404        """
405        Get the time in ms which was required for executing the query.
406
407        Returns
408        -------
409        double
410            Query execution time as double in ms.
411
412        """
413        self.check_for_query_result_close()
414        return self._query_result.getExecutionTime()
415
416    def get_compiling_time(self) -> int:
417        """
418        Get the time in ms which was required for compiling the query.
419
420        Returns
421        -------
422        double
423            Query compile time as double in ms.
424
425        """
426        self.check_for_query_result_close()
427        return self._query_result.getCompilingTime()
428
429    def get_num_tuples(self) -> int:
430        """
431        Get the number of tuples which the query returned.
432
433        Returns
434        -------
435        int
436            Number of tuples.
437
438        """
439        self.check_for_query_result_close()
440        return self._query_result.getNumTuples()

QueryResult stores the result of a query execution.

QueryResult(connection: Connection, query_result: '_kuzu.QueryResult')
30    def __init__(self, connection: _kuzu.Connection, query_result: _kuzu.QueryResult):  # type: ignore[name-defined]
31        """
32        Parameters
33        ----------
34        connection : _kuzu.Connection
35            The underlying C++ connection object from pybind11.
36
37        query_result : _kuzu.QueryResult
38            The underlying C++ query result object from pybind11.
39
40        """
41        self.connection = connection
42        self._query_result = query_result
43        self.is_closed = False
Parameters
  • connection (_kuzu.Connection): The underlying C++ connection object from pybind11.
  • query_result (_kuzu.QueryResult): The underlying C++ query result object from pybind11.
connection
is_closed
def check_for_query_result_close(self) -> None:
59    def check_for_query_result_close(self) -> None:
60        """
61        Check if the query result is closed and raise an exception if it is.
62
63        Raises
64        ------
65        Exception
66            If the query result is closed.
67
68        """
69        if self.is_closed:
70            msg = "Query result is closed"
71            raise RuntimeError(msg)

Check if the query result is closed and raise an exception if it is.

Raises
  • RuntimeError: If the query result is closed.
def has_next(self) -> bool:
73    def has_next(self) -> bool:
74        """
75        Check if there are more rows in the query result.
76
77        Returns
78        -------
79        bool
80            True if there are more rows in the query result, False otherwise.
81
82        """
83        self.check_for_query_result_close()
84        return self._query_result.hasNext()

Check if there are more rows in the query result.

Returns
  • bool: True if there are more rows in the query result, False otherwise.
def get_next(self) -> list[typing.Any]:
86    def get_next(self) -> list[Any]:
87        """
88        Get the next row in the query result.
89
90        Returns
91        -------
92        list
93            Next row in the query result.
94
95        """
96        self.check_for_query_result_close()
97        return self._query_result.getNext()

Get the next row in the query result.

Returns
  • list: Next row in the query result.
def close(self) -> None:
 99    def close(self) -> None:
100        """Close the query result."""
101        if not self.is_closed:
102            # Allows the connection to be garbage collected if the query result
103            # is closed manually by the user.
104            self._query_result.close()
105            self.connection = None
106            self.is_closed = True

Close the query result.

def get_as_df(self) -> pandas.core.frame.DataFrame:
108    def get_as_df(self) -> pd.DataFrame:
109        """
110        Get the query result as a Pandas DataFrame.
111
112        See Also
113        --------
114        get_as_pl : Get the query result as a Polars DataFrame.
115        get_as_arrow : Get the query result as a PyArrow Table.
116
117        Returns
118        -------
119        pandas.DataFrame
120            Query result as a Pandas DataFrame.
121
122        """
123        self.check_for_query_result_close()
124
125        return self._query_result.getAsDF()

Get the query result as a Pandas DataFrame.

See Also

get_as_pl: Get the query result as a Polars DataFrame.
get_as_arrow: Get the query result as a PyArrow Table.

Returns
  • pandas.DataFrame: Query result as a Pandas DataFrame.
def get_as_pl(self) -> polars.dataframe.frame.DataFrame:
127    def get_as_pl(self) -> pl.DataFrame:
128        """
129        Get the query result as a Polars DataFrame.
130
131        See Also
132        --------
133        get_as_df : Get the query result as a Pandas DataFrame.
134        get_as_arrow : Get the query result as a PyArrow Table.
135
136        Returns
137        -------
138        polars.DataFrame
139            Query result as a Polars DataFrame.
140        """
141        import polars as pl
142
143        self.check_for_query_result_close()
144
145        # note: polars should always export just a single chunk,
146        # (eg: "-1") otherwise it will just need to rechunk anyway
147        return pl.from_arrow(  # type: ignore[return-value]
148            data=self.get_as_arrow(chunk_size=-1),
149        )

Get the query result as a Polars DataFrame.

See Also

get_as_df: Get the query result as a Pandas DataFrame.
get_as_arrow: Get the query result as a PyArrow Table.

Returns
  • polars.DataFrame: Query result as a Polars DataFrame.
def get_as_arrow(self, chunk_size: int | None = None) -> pyarrow.lib.Table:
151    def get_as_arrow(self, chunk_size: int | None = None) -> pa.Table:
152        """
153        Get the query result as a PyArrow Table.
154
155        Parameters
156        ----------
157        chunk_size : Number of rows to include in each chunk.
158            None
159                The chunk size is adaptive and depends on the number of columns in the query result.
160            -1 or 0
161                The entire result is returned as a single chunk.
162            > 0
163                The chunk size is the number of rows specified.
164
165        See Also
166        --------
167        get_as_pl : Get the query result as a Polars DataFrame.
168        get_as_df : Get the query result as a Pandas DataFrame.
169
170        Returns
171        -------
172        pyarrow.Table
173            Query result as a PyArrow Table.
174        """
175        self.check_for_query_result_close()
176
177        if chunk_size is None:
178            # Adaptive; target 10m total elements in each chunk.
179            # (eg: if we had 10 cols, this would result in a 1m row chunk_size).
180            target_n_elems = 10_000_000
181            chunk_size = max(target_n_elems // len(self.get_column_names()), 10)
182        elif chunk_size <= 0:
183            # No chunking: return the entire result as a single chunk
184            chunk_size = self.get_num_tuples()
185
186        return self._query_result.getAsArrow(chunk_size)

Get the query result as a PyArrow Table.

Parameters
  • chunk_size (int | None): Number of rows to include in each chunk. None: the chunk size is adaptive and depends on the number of columns in the query result. 0 or a negative value (e.g. -1): the entire result is returned as a single chunk. > 0: the chunk size is the number of rows specified.
See Also

get_as_pl: Get the query result as a Polars DataFrame.
get_as_df: Get the query result as a Pandas DataFrame.

Returns
  • pyarrow.Table: Query result as a PyArrow Table.
def get_column_data_types(self) -> list[str]:
188    def get_column_data_types(self) -> list[str]:
189        """
190        Get the data types of the columns in the query result.
191
192        Returns
193        -------
194        list
195            Data types of the columns in the query result.
196
197        """
198        self.check_for_query_result_close()
199        return self._query_result.getColumnDataTypes()

Get the data types of the columns in the query result.

Returns
  • list: Data types of the columns in the query result.
def get_column_names(self) -> list[str]:
201    def get_column_names(self) -> list[str]:
202        """
203        Get the names of the columns in the query result.
204
205        Returns
206        -------
207        list
208            Names of the columns in the query result.
209
210        """
211        self.check_for_query_result_close()
212        return self._query_result.getColumnNames()

Get the names of the columns in the query result.

Returns
  • list: Names of the columns in the query result.
def get_schema(self) -> dict[str, str]:
214    def get_schema(self) -> dict[str, str]:
215        """
216        Get the column schema of the query result.
217
218        Returns
219        -------
220        dict
221            Schema of the query result.
222
223        """
224        self.check_for_query_result_close()
225        return dict(
226            zip(
227                self._query_result.getColumnNames(),
228                self._query_result.getColumnDataTypes(),
229            )
230        )

Get the column schema of the query result.

Returns
  • dict: Schema of the query result.
def reset_iterator(self) -> None:
232    def reset_iterator(self) -> None:
233        """Reset the iterator of the query result."""
234        self.check_for_query_result_close()
235        self._query_result.resetIterator()

Reset the iterator of the query result.

def get_as_networkx( self, directed: bool = True) -> networkx.classes.multigraph.MultiGraph | networkx.classes.multidigraph.MultiDiGraph:
237    def get_as_networkx(
238        self,
239        directed: bool = True,  # noqa: FBT001
240    ) -> nx.MultiGraph | nx.MultiDiGraph:
241        """
242        Convert the nodes and rels in query result into a NetworkX directed or undirected graph
243        with the following rules:
244        Columns with data type other than node or rel will be ignored.
245        Duplicated nodes and rels will be converted only once.
246
247        Parameters
248        ----------
249        directed : bool
250            Whether the graph should be directed. Defaults to True.
251
252        Returns
253        -------
254        networkx.MultiDiGraph or networkx.MultiGraph
255            Query result as a NetworkX graph.
256
257        """
258        self.check_for_query_result_close()
259        import networkx as nx
260
261        nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph()
262        properties_to_extract = self._get_properties_to_extract()
263
264        self.reset_iterator()
265
266        nodes = {}
267        rels = {}
268        table_to_label_dict = {}
269        table_primary_key_dict = {}
270
271        def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
272            node_label = node["_label"]
273            return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"
274
275        def encode_rel_id(rel: dict[str, Any]) -> tuple[int, int]:
276            return rel["_id"]["table"], rel["_id"]["offset"]
277
278        # De-duplicate nodes and rels
279        while self.has_next():
280            row = self.get_next()
281            for i in properties_to_extract:
282                # Skip empty nodes and rels, which may be returned by
283                # OPTIONAL MATCH
284                if row[i] is None or row[i] == {}:
285                    continue
286                column_type, _ = properties_to_extract[i]
287                if column_type == Type.NODE.value:
288                    _id = row[i]["_id"]
289                    nodes[(_id["table"], _id["offset"])] = row[i]
290                    table_to_label_dict[_id["table"]] = row[i]["_label"]
291
292                elif column_type == Type.REL.value:
293                    _src = row[i]["_src"]
294                    _dst = row[i]["_dst"]
295                    rels[encode_rel_id(row[i])] = row[i]
296
297                elif column_type == Type.RECURSIVE_REL.value:
298                    for node in row[i]["_nodes"]:
299                        _id = node["_id"]
300                        nodes[(_id["table"], _id["offset"])] = node
301                        table_to_label_dict[_id["table"]] = node["_label"]
302                    for rel in row[i]["_rels"]:
303                        for key in list(rel.keys()):
304                            if rel[key] is None:
305                                del rel[key]
306                        _src = rel["_src"]
307                        _dst = rel["_dst"]
308                        rels[encode_rel_id(rel)] = rel
309
310        # Add nodes
311        for node in nodes.values():
312            _id = node["_id"]
313            node_id = node["_label"] + "_" + str(_id["offset"])
314            if node["_label"] not in table_primary_key_dict:
315                props = self.connection._get_node_property_names(node["_label"])
316                for prop_name in props:
317                    if props[prop_name]["is_primary_key"]:
318                        table_primary_key_dict[node["_label"]] = prop_name
319                        break
320            node_id = encode_node_id(node, table_primary_key_dict)
321            node[node["_label"]] = True
322            nx_graph.add_node(node_id, **node)
323
324        # Add rels
325        for rel in rels.values():
326            _src = rel["_src"]
327            _dst = rel["_dst"]
328            src_node = nodes[(_src["table"], _src["offset"])]
329            dst_node = nodes[(_dst["table"], _dst["offset"])]
330            src_id = encode_node_id(src_node, table_primary_key_dict)
331            dst_id = encode_node_id(dst_node, table_primary_key_dict)
332            nx_graph.add_edge(src_id, dst_id, **rel)
333        return nx_graph

Convert the nodes and rels in the query result into a NetworkX directed or undirected graph, following these rules: columns whose data type is not node, rel, or recursive rel are ignored, and duplicate nodes and rels are converted only once.

Parameters
  • directed (bool): Whether the graph should be directed. Defaults to True.
Returns
  • networkx.MultiDiGraph or networkx.MultiGraph: Query result as a NetworkX graph.
def get_as_torch_geometric( self) -> tuple[torch_geometric.data.data.Data | torch_geometric.data.hetero_data.HeteroData, dict, dict, dict]:
352    def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]:  # type: ignore[type-arg]
353        """
354        Converts the nodes and rels in query result into a PyTorch Geometric graph representation
355        torch_geometric.data.Data or torch_geometric.data.HeteroData.
356
357        For node conversion, numerical and boolean properties are directly converted into tensor and
358        stored in Data/HeteroData. For properties cannot be converted into tensor automatically
359        (please refer to the notes below for more detail), they are returned as unconverted_properties.
360
361        For rel conversion, rel is converted into edge_index tensor director. Edge properties are returned
362        as edge_properties.
363
364        Node properties that cannot be converted into tensor automatically:
365        - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted
366          automatically.
367        - If a node property contains a null value, it cannot be converted automatically.
368        - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be
369          converted automatically.
370        - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length
371          is 6 for one node but 5 for another node), it cannot be converted automatically.
372
373        Additional conversion rules:
374        - Columns with data type other than node or rel will be ignored.
375        - Duplicated nodes and rels will be converted only once.
376
377        Returns
378        -------
379        torch_geometric.data.Data or torch_geometric.data.HeteroData
380            Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties
381            and edge_index tensor.
382
383        dict
384            A dictionary that maps the positional offset of each node in Data/HeteroData to its primary
385            key in the database.
386
387        dict
388            A dictionary contains node properties that cannot be converted into tensor automatically. The
389            order of values for each property is aligned with nodes in Data/HeteroData.
390
391        dict
392            A dictionary contains edge properties. The order of values for each property is aligned with
393            edge_index in Data/HeteroData.
394        """
395        self.check_for_query_result_close()
396        # Despite we are not using torch_geometric in this file, we need to
397        # import it here to throw an error early if the user does not have
398        # torch_geometric or torch installed.
399
400        converter = TorchGeometricResultConverter(self)
401        return converter.get_as_torch_geometric()

Convert the nodes and rels in the query result into a PyTorch Geometric graph representation: torch_geometric.data.Data or torch_geometric.data.HeteroData.

For node conversion, numerical and boolean properties are converted directly into tensors and stored in Data/HeteroData. Properties that cannot be converted into tensors automatically (see the notes below for details) are returned as unconverted_properties.

For rel conversion, rels are converted directly into an edge_index tensor. Edge properties are returned as edge_properties.

Node properties that cannot be converted into tensors automatically:

  • If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted automatically.
  • If a node property contains a null value, it cannot be converted automatically.
  • If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be converted automatically.
  • If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length is 6 for one node but 5 for another node), it cannot be converted automatically.

Additional conversion rules:

  • Columns with data type other than node or rel will be ignored.
  • Duplicated nodes and rels will be converted only once.
Returns
  • torch_geometric.data.Data or torch_geometric.data.HeteroData: Query result as a PyTorch Geometric graph, containing the numeric or boolean node properties and the edge_index tensor.
  • dict: A dictionary that maps the positional offset of each node in Data/HeteroData to its primary key in the database.
  • dict: A dictionary containing node properties that cannot be converted into tensors automatically. The order of values for each property is aligned with the nodes in Data/HeteroData.
  • dict: A dictionary containing edge properties. The order of values for each property is aligned with edge_index in Data/HeteroData.
def get_execution_time(self) -> int:
403    def get_execution_time(self) -> int:
404        """
405        Get the time in ms which was required for executing the query.
406
407        Returns
408        -------
409        double
410            Query execution time as double in ms.
411
412        """
413        self.check_for_query_result_close()
414        return self._query_result.getExecutionTime()

Get the time, in milliseconds, taken to execute the query.

Returns
  • float: Query execution time in milliseconds.
def get_compiling_time(self) -> int:
416    def get_compiling_time(self) -> int:
417        """
418        Get the time in ms which was required for compiling the query.
419
420        Returns
421        -------
422        double
423            Query compile time as double in ms.
424
425        """
426        self.check_for_query_result_close()
427        return self._query_result.getCompilingTime()

Get the time, in milliseconds, taken to compile the query.

Returns
  • float: Query compilation time in milliseconds.
def get_num_tuples(self) -> int:
429    def get_num_tuples(self) -> int:
430        """
431        Get the number of tuples which the query returned.
432
433        Returns
434        -------
435        int
436            Number of tuples.
437
438        """
439        self.check_for_query_result_close()
440        return self._query_result.getNumTuples()

Get the number of tuples returned by the query.

Returns
  • int: Number of tuples.
class Type(enum.Enum):
 5class Type(Enum):
 6    """The type of a value in the database."""
 7
 8    ANY = "ANY"
 9    NODE = "NODE"
10    REL = "REL"
11    RECURSIVE_REL = "RECURSIVE_REL"
12    SERIAL = "SERIAL"
13    BOOL = "BOOL"
14    INT64 = "INT64"
15    INT32 = "INT32"
16    INT16 = "INT16"
17    INT8 = "INT8"
18    UINT64 = "UINT64"
19    UINT32 = "UINT32"
20    UINT16 = "UINT16"
21    UINT8 = "UINT8"
22    INT128 = "INT128"
23    DOUBLE = "DOUBLE"
24    FLOAT = "FLOAT"
25    DATE = "DATE"
26    TIMESTAMP = "TIMESTAMP"
27    TIMSTAMP_TZ = "TIMESTAMP_TZ"
28    TIMESTAMP_NS = "TIMESTAMP_NS"
29    TIMESTAMP_MS = "TIMESTAMP_MS"
30    TIMESTAMP_SEC = "TIMESTAMP_SEC"
31    INTERVAL = "INTERVAL"
32    INTERNAL_ID = "INTERNAL_ID"
33    STRING = "STRING"
34    BLOB = "BLOB"
35    UUID = "UUID"
36    LIST = "LIST"
37    ARRAY = "ARRAY"
38    STRUCT = "STRUCT"
39    MAP = "MAP"
40    UNION = "UNION"

The type of a value in the database.

ANY = <Type.ANY: 'ANY'>
NODE = <Type.NODE: 'NODE'>
REL = <Type.REL: 'REL'>
RECURSIVE_REL = <Type.RECURSIVE_REL: 'RECURSIVE_REL'>
SERIAL = <Type.SERIAL: 'SERIAL'>
BOOL = <Type.BOOL: 'BOOL'>
INT64 = <Type.INT64: 'INT64'>
INT32 = <Type.INT32: 'INT32'>
INT16 = <Type.INT16: 'INT16'>
INT8 = <Type.INT8: 'INT8'>
UINT64 = <Type.UINT64: 'UINT64'>
UINT32 = <Type.UINT32: 'UINT32'>
UINT16 = <Type.UINT16: 'UINT16'>
UINT8 = <Type.UINT8: 'UINT8'>
INT128 = <Type.INT128: 'INT128'>
DOUBLE = <Type.DOUBLE: 'DOUBLE'>
FLOAT = <Type.FLOAT: 'FLOAT'>
DATE = <Type.DATE: 'DATE'>
TIMESTAMP = <Type.TIMESTAMP: 'TIMESTAMP'>
TIMSTAMP_TZ = <Type.TIMSTAMP_TZ: 'TIMESTAMP_TZ'>
TIMESTAMP_NS = <Type.TIMESTAMP_NS: 'TIMESTAMP_NS'>
TIMESTAMP_MS = <Type.TIMESTAMP_MS: 'TIMESTAMP_MS'>
TIMESTAMP_SEC = <Type.TIMESTAMP_SEC: 'TIMESTAMP_SEC'>
INTERVAL = <Type.INTERVAL: 'INTERVAL'>
INTERNAL_ID = <Type.INTERNAL_ID: 'INTERNAL_ID'>
STRING = <Type.STRING: 'STRING'>
BLOB = <Type.BLOB: 'BLOB'>
UUID = <Type.UUID: 'UUID'>
LIST = <Type.LIST: 'LIST'>
ARRAY = <Type.ARRAY: 'ARRAY'>
STRUCT = <Type.STRUCT: 'STRUCT'>
MAP = <Type.MAP: 'MAP'>
UNION = <Type.UNION: 'UNION'>
__version__
storage_version
version