kuzu
Kùzu Python API bindings.
This package provides a Python API for Kùzu graph database management system.
To install the package, run:
python3 -m pip install kuzu
Example usage:
import kuzu
db = kuzu.Database("./test")
conn = kuzu.Connection(db)
# Define the schema
conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))")
conn.execute("CREATE NODE TABLE City(name STRING, population INT64, PRIMARY KEY (name))")
conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)")
conn.execute("CREATE REL TABLE LivesIn(FROM User TO City)")
# Load some data
conn.execute('COPY User FROM "user.csv"')
conn.execute('COPY City FROM "city.csv"')
conn.execute('COPY Follows FROM "follows.csv"')
conn.execute('COPY LivesIn FROM "lives-in.csv"')
# Query the data
results = conn.execute("MATCH (u:User) RETURN u.name, u.age;")
while results.has_next():
    print(results.get_next())
The dataset used in this example can be found [here](https://github.com/kuzudb/kuzu/tree/master/dataset/demo-db/csv).
1""" 2# Kùzu Python API bindings. 3 4This package provides a Python API for Kùzu graph database management system. 5 6To install the package, run: 7``` 8python3 -m pip install kuzu 9``` 10 11Example usage: 12```python 13import kuzu 14 15db = kuzu.Database("./test") 16conn = kuzu.Connection(db) 17 18# Define the schema 19conn.execute("CREATE NODE TABLE User(name STRING, age INT64, PRIMARY KEY (name))") 20conn.execute("CREATE NODE TABLE City(name STRING, population INT64, PRIMARY KEY (name))") 21conn.execute("CREATE REL TABLE Follows(FROM User TO User, since INT64)") 22conn.execute("CREATE REL TABLE LivesIn(FROM User TO City)") 23 24# Load some data 25conn.execute('COPY User FROM "user.csv"') 26conn.execute('COPY City FROM "city.csv"') 27conn.execute('COPY Follows FROM "follows.csv"') 28conn.execute('COPY LivesIn FROM "lives-in.csv"') 29 30# Query the data 31results = conn.execute("MATCH (u:User) RETURN u.name, u.age;") 32while results.has_next(): 33 print(results.get_next()) 34``` 35 36The dataset used in this example can be found [here](https://github.com/kuzudb/kuzu/tree/master/dataset/demo-db/csv). 
37 38""" 39 40from __future__ import annotations 41 42import os 43import sys 44 45# Set RTLD_GLOBAL and RTLD_LAZY flags on Linux to fix the issue with loading 46# extensions 47if sys.platform == "linux": 48 original_dlopen_flags = sys.getdlopenflags() 49 sys.setdlopenflags(os.RTLD_GLOBAL | os.RTLD_LAZY) 50 51from .connection import Connection 52from .database import Database 53from .prepared_statement import PreparedStatement 54from .query_result import QueryResult 55from .types import Type 56 57 58def __getattr__(name: str) -> str | int: 59 if name in ("version", "__version__"): 60 return Database.get_version() 61 elif name == "storage_version": 62 return Database.get_storage_version() 63 else: 64 msg = f"module {__name__!r} has no attribute {name!r}" 65 raise AttributeError(msg) 66 67 68# Restore the original dlopen flags 69if sys.platform == "linux": 70 sys.setdlopenflags(original_dlopen_flags) 71 72__all__ = [ 73 "Connection", 74 "Database", 75 "PreparedStatement", 76 "QueryResult", 77 "Type", 78 "__version__", 79 "storage_version", 80 "version", 81]
class Connection:
    """Connection to a database."""

    def __init__(self, database: Database, num_threads: int = 0):
        """
        Initialise kuzu database connection.

        Parameters
        ----------
        database : Database
            Database to connect to.

        num_threads : int
            Maximum number of threads to use for executing queries.

        """
        self._connection: Any = None  # (type: _kuzu.Connection from pybind11)
        self.database = database
        self.num_threads = num_threads
        self.is_closed = False
        self.init_connection()

    def __getstate__(self) -> dict[str, Any]:
        """
        Return the picklable state of the connection.

        The native connection handle is never serialised: it is stored as
        ``None`` and `init_connection` creates a new one when it finds no
        handle. ``is_closed`` is part of the state because `init_connection`
        reads it; without it a restored copy would raise ``AttributeError``.
        """
        return {
            "database": self.database,
            "num_threads": self.num_threads,
            "_connection": None,
            "is_closed": self.is_closed,
        }

    def init_connection(self) -> None:
        """
        Establish a connection to the database, if not already initialised.

        Raises
        ------
        RuntimeError
            If the connection has been closed.
        """
        if self.is_closed:
            error_msg = "Connection is closed."
            raise RuntimeError(error_msg)
        self.database.init_database()
        if self._connection is None:
            self._connection = _kuzu.Connection(self.database._database, self.num_threads)  # type: ignore[union-attr]

    def set_max_threads_for_exec(self, num_threads: int) -> None:
        """
        Set the maximum number of threads for executing queries.

        Parameters
        ----------
        num_threads : int
            Maximum number of threads to use for executing queries.

        """
        self.init_connection()
        self._connection.set_max_threads_for_exec(num_threads)

    def close(self) -> None:
        """
        Close the connection.

        Note: Call to this method is optional. The connection will be closed
        automatically when the object goes out of scope.
        """
        if self._connection is not None:
            self._connection.close()
        self._connection = None
        self.is_closed = True

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    def execute(
        self,
        query: str | PreparedStatement,
        parameters: dict[str, Any] | None = None,
    ) -> QueryResult | list[QueryResult]:
        """
        Execute a query.

        Parameters
        ----------
        query : str | PreparedStatement
            A prepared statement or a query string.
            If a query string is given together with parameters, a prepared
            statement will be created automatically.

        parameters : dict[str, Any]
            Parameters for the query.

        Returns
        -------
        QueryResult | list[QueryResult]
            A single query result, or a list of results when the native
            result reports that further query results follow it.

        Raises
        ------
        RuntimeError
            If the connection is closed, if ``parameters`` is not a dict, or
            if any of the query results reports a failure.

        """
        if parameters is None:
            parameters = {}

        self.init_connection()
        if not isinstance(parameters, dict):
            msg = f"Parameters must be a dict; found {type(parameters)}."
            raise RuntimeError(msg)  # noqa: TRY004

        # Only a plain query string without parameters can use the direct
        # query path; a PreparedStatement must always go through execute(),
        # even when it takes no parameters.
        if isinstance(query, str) and not parameters:
            _query_result = self._connection.query(query)
        else:
            prepared_statement = self.prepare(query) if isinstance(query, str) else query
            _query_result = self._connection.execute(prepared_statement._prepared_statement, parameters)
        if not _query_result.isSuccess():
            raise RuntimeError(_query_result.getErrorMessage())
        current_query_result = QueryResult(self, _query_result)
        if not _query_result.hasNextQueryResult():
            return current_query_result
        all_query_results = [current_query_result]
        while _query_result.hasNextQueryResult():
            _query_result = _query_result.getNextQueryResult()
            if not _query_result.isSuccess():
                raise RuntimeError(_query_result.getErrorMessage())
            all_query_results.append(QueryResult(self, _query_result))
        return all_query_results

    def prepare(self, query: str) -> PreparedStatement:
        """
        Create a prepared statement for a query.

        Parameters
        ----------
        query : str
            Query to prepare.

        Returns
        -------
        PreparedStatement
            Prepared statement.

        """
        return PreparedStatement(self, query)

    def _get_node_property_names(self, table_name: str) -> dict[str, Any]:
        """
        Describe the properties of a table via ``CALL table_info``.

        Returns a dict keyed by property name. Each value holds the base
        ``type`` (the type string up to the first ``[``), the ``dimension``
        (number of ``[`` in the type string), ``is_primary_key`` and, when
        bracketed sizes are present, a ``shape`` tuple of those sizes.
        """
        LIST_START_SYMBOL = "["
        LIST_END_SYMBOL = "]"
        self.init_connection()
        query_result = self.execute(f"CALL table_info('{table_name}') RETURN *;")
        results = {}
        while query_result.has_next():
            row = query_result.get_next()
            prop_name = row[1]
            prop_type = row[2]
            is_primary_key = row[4] is True
            dimension = prop_type.count(LIST_START_SYMBOL)
            splitted = prop_type.split(LIST_START_SYMBOL)
            shape = []
            for fragment in splitted:
                if LIST_END_SYMBOL not in fragment:
                    continue
                # Text before the closing bracket is the size; it is empty
                # when the brackets carry no size.
                size_text = fragment.split(LIST_END_SYMBOL)[0]
                if size_text != "":
                    shape.append(int(size_text))
            prop_type = splitted[0]
            results[prop_name] = {
                "type": prop_type,
                "dimension": dimension,
                "is_primary_key": is_primary_key,
            }
            if len(shape) > 0:
                results[prop_name]["shape"] = tuple(shape)
        return results

    def _get_node_table_names(self) -> list[Any]:
        """Return the names of all tables that ``show_tables`` reports as ``NODE``."""
        results = []
        self.init_connection()
        query_result = self.execute("CALL show_tables() RETURN *;")
        while query_result.has_next():
            row = query_result.get_next()
            if row[2] == "NODE":
                results.append(row[1])
        return results

    def _get_rel_table_names(self) -> list[dict[str, Any]]:
        """
        Return one ``{"name", "src", "dst"}`` dict per ``REL`` table.

        ``src`` and ``dst`` are the first two columns of the first row that
        ``show_connection`` returns for the table.
        """
        results = []
        self.init_connection()
        tables_result = self.execute("CALL show_tables() RETURN *;")
        while tables_result.has_next():
            row = tables_result.get_next()
            if row[2] == "REL":
                name = row[1]
                connections_result = self.execute(f"CALL show_connection({name!r}) RETURN *;")
                src_dst_row = connections_result.get_next()
                src_node = src_dst_row[0]
                dst_node = src_dst_row[1]
                results.append({"name": name, "src": src_node, "dst": dst_node})
        return results

    def set_query_timeout(self, timeout_in_ms: int) -> None:
        """
        Set the query timeout value in ms for executing queries.

        Parameters
        ----------
        timeout_in_ms : int
            query timeout value in ms for executing queries.

        """
        self.init_connection()
        self._connection.set_query_timeout(timeout_in_ms)

    def create_function(
        self,
        name: str,
        udf: Callable[[...], Any],
        params_type: list[Type | str] | None = None,
        return_type: Type | str = "",
        *,
        default_null_handling: bool = True,
        catch_exceptions: bool = False,
    ) -> None:
        """
        Sets a User Defined Function (UDF) to use in cypher queries.

        Parameters
        ----------
        name: str
            name of function

        udf: Callable[[...], Any]
            function to be executed

        params_type: Optional[list[Type]]
            list of Type enums to describe the input parameters

        return_type: Optional[Type]
            a Type enum to describe the returned value

        default_null_handling: Optional[bool]
            if true, when any parameter is null, the resulting value will be null

        catch_exceptions: Optional[bool]
            if true, when an exception is thrown from python, the function output will be null
            Otherwise, the exception will be rethrown

        Raises
        ------
        RuntimeError
            If the connection has been closed.
        """
        # Same guard as the other methods that use the native handle: fail
        # with a clear error when closed, and create the handle if missing.
        self.init_connection()
        if params_type is None:
            params_type = []
        # Strings are passed through unchanged; anything else is reduced to
        # its ``.value``.
        parsed_params_type = [x if isinstance(x, str) else x.value for x in params_type]
        if not isinstance(return_type, str):
            return_type = return_type.value

        self._connection.create_function(
            name=name,
            udf=udf,
            params_type=parsed_params_type,
            return_value=return_type,
            default_null=default_null_handling,
            catch_exceptions=catch_exceptions,
        )

    def remove_function(self, name: str) -> None:
        """
        Removes a User Defined Function (UDF).

        Parameters
        ----------
        name: str
            name of function to be removed.

        Raises
        ------
        RuntimeError
            If the connection has been closed.
        """
        self.init_connection()
        self._connection.remove_function(name)
Connection to a database.
def __init__(self, database: Database, num_threads: int = 0):
    """
    Create a connection bound to ``database``.

    Parameters
    ----------
    database : Database
        Database to connect to.

    num_threads : int
        Maximum number of threads to use for executing queries.

    """
    # Plain attribute setup first; the native handle is created by the
    # init_connection() call at the end.
    self.is_closed = False
    self.num_threads = num_threads
    self.database = database
    self._connection: Any = None  # (type: _kuzu.Connection from pybind11)
    self.init_connection()
Initialise kuzu database connection.
Parameters
- database (Database): Database to connect to.
- num_threads (int): Maximum number of threads to use for executing queries.
def init_connection(self) -> None:
    """
    Establish a connection to the database, if not already initialised.

    Raises ``RuntimeError`` when the connection has been closed.
    """
    if self.is_closed:
        raise RuntimeError("Connection is closed.")
    self.database.init_database()
    if self._connection is not None:
        # A native handle already exists; nothing more to do.
        return
    self._connection = _kuzu.Connection(self.database._database, self.num_threads)  # type: ignore[union-attr]
Establish a connection to the database, if not already initialised.
def set_max_threads_for_exec(self, num_threads: int) -> None:
    """
    Set the upper bound on threads used to execute queries.

    Parameters
    ----------
    num_threads : int
        Maximum number of threads to use for executing queries.

    """
    self.init_connection()
    native = self._connection
    native.set_max_threads_for_exec(num_threads)
Set the maximum number of threads for executing queries.
Parameters
- num_threads (int): Maximum number of threads to use for executing queries.
def close(self) -> None:
    """
    Close the connection.

    Note: Call to this method is optional. The connection will be closed
    automatically when the object goes out of scope.
    """
    native = self._connection
    if native is not None:
        native.close()
    # Drop the handle and mark the object closed.
    self._connection = None
    self.is_closed = True
Close the connection.
Note: Call to this method is optional. The connection will be closed automatically when the object goes out of scope.
def execute(
    self,
    query: str | PreparedStatement,
    parameters: dict[str, Any] | None = None,
) -> QueryResult | list[QueryResult]:
    """
    Execute a query.

    Parameters
    ----------
    query : str | PreparedStatement
        A prepared statement or a query string.
        If a query string is given together with parameters, a prepared
        statement will be created automatically.

    parameters : dict[str, Any]
        Parameters for the query.

    Returns
    -------
    QueryResult | list[QueryResult]
        A single query result, or a list of results when the native result
        reports that further query results follow it.

    Raises
    ------
    RuntimeError
        If ``parameters`` is not a dict, or if any of the query results
        reports a failure.

    """
    if parameters is None:
        parameters = {}

    self.init_connection()
    if not isinstance(parameters, dict):
        msg = f"Parameters must be a dict; found {type(parameters)}."
        raise RuntimeError(msg)  # noqa: TRY004

    # Only a plain query string without parameters can use the direct query
    # path; a PreparedStatement must always go through execute(), even when
    # it takes no parameters.
    if isinstance(query, str) and not parameters:
        _query_result = self._connection.query(query)
    else:
        prepared_statement = self.prepare(query) if isinstance(query, str) else query
        _query_result = self._connection.execute(prepared_statement._prepared_statement, parameters)
    if not _query_result.isSuccess():
        raise RuntimeError(_query_result.getErrorMessage())
    current_query_result = QueryResult(self, _query_result)
    if not _query_result.hasNextQueryResult():
        return current_query_result
    all_query_results = [current_query_result]
    while _query_result.hasNextQueryResult():
        _query_result = _query_result.getNextQueryResult()
        if not _query_result.isSuccess():
            raise RuntimeError(_query_result.getErrorMessage())
        all_query_results.append(QueryResult(self, _query_result))
    return all_query_results
Execute a query.
Parameters
- query (str | PreparedStatement): A prepared statement or a query string. If a query string is given, a prepared statement will be created automatically.
- parameters (dict[str, Any]): Parameters for the query.
Returns
- QueryResult: Query result.
def prepare(self, query: str) -> PreparedStatement:
    """
    Build a prepared statement for ``query`` on this connection.

    Parameters
    ----------
    query : str
        Query to prepare.

    Returns
    -------
    PreparedStatement
        Prepared statement.

    """
    statement = PreparedStatement(self, query)
    return statement
Create a prepared statement for a query.
Parameters
- query (str): Query to prepare.
Returns
- PreparedStatement: Prepared statement.
def set_query_timeout(self, timeout_in_ms: int) -> None:
    """
    Set the timeout, in milliseconds, applied when executing queries.

    Parameters
    ----------
    timeout_in_ms : int
        query timeout value in ms for executing queries.

    """
    self.init_connection()
    native = self._connection
    native.set_query_timeout(timeout_in_ms)
Set the query timeout value in ms for executing queries.
Parameters
- timeout_in_ms (int): query timeout value in ms for executing queries.
def create_function(
    self,
    name: str,
    udf: Callable[[...], Any],
    params_type: list[Type | str] | None = None,
    return_type: Type | str = "",
    *,
    default_null_handling: bool = True,
    catch_exceptions: bool = False,
) -> None:
    """
    Sets a User Defined Function (UDF) to use in cypher queries.

    Parameters
    ----------
    name: str
        name of function

    udf: Callable[[...], Any]
        function to be executed

    params_type: Optional[list[Type]]
        list of Type enums to describe the input parameters

    return_type: Optional[Type]
        a Type enum to describe the returned value

    default_null_handling: Optional[bool]
        if true, when any parameter is null, the resulting value will be null

    catch_exceptions: Optional[bool]
        if true, when an exception is thrown from python, the function output will be null
        Otherwise, the exception will be rethrown
    """
    # Run the connection's init_connection() step before touching the native
    # handle, instead of dereferencing self._connection unchecked.
    self.init_connection()
    if params_type is None:
        params_type = []
    # Strings are passed through unchanged; anything else is reduced to its
    # ``.value``.
    parsed_params_type = [x if isinstance(x, str) else x.value for x in params_type]
    if not isinstance(return_type, str):
        return_type = return_type.value

    self._connection.create_function(
        name=name,
        udf=udf,
        params_type=parsed_params_type,
        return_value=return_type,
        default_null=default_null_handling,
        catch_exceptions=catch_exceptions,
    )
Sets a User Defined Function (UDF) to use in cypher queries.
Parameters
- name (str): name of function
- udf (Callable[[...], Any]): function to be executed
- params_type (Optional[list[Type]]): list of Type enums to describe the input parameters
- return_type (Optional[Type]): a Type enum to describe the returned value
- default_null_handling (Optional[bool]): if true, when any parameter is null, the resulting value will be null
- catch_exceptions (Optional[bool]): if true, when an exception is thrown from python, the function output will be null. Otherwise, the exception will be rethrown.
def remove_function(self, name: str) -> None:
    """
    Removes a User Defined Function (UDF).

    Parameters
    ----------
    name: str
        name of function to be removed.
    """
    # Run the connection's init_connection() step before touching the native
    # handle, instead of dereferencing self._connection unchecked.
    self.init_connection()
    self._connection.remove_function(name)
Removes a User Defined Function (UDF).
Parameters
- name (str): name of function to be removed.
class Database:
    """Kùzu database instance."""

    def __init__(
        self,
        database_path: str | Path | None = None,
        *,
        buffer_pool_size: int = 0,
        max_num_threads: int = 0,
        compression: bool = True,
        lazy_init: bool = False,
        read_only: bool = False,
        max_db_size: int = (1 << 43),
    ):
        """
        Parameters
        ----------
        database_path : str, Path
            The path to database files. If the path is not specified, or empty, or equal to `:memory:`, the database
            will be created in memory.

        buffer_pool_size : int
            The maximum size of buffer pool in bytes. Defaults to ~80% of system memory.

        max_num_threads : int
            The maximum number of threads to use for executing queries.

        compression : bool
            Enable database compression.

        lazy_init : bool
            If True, the database will not be initialized until the first query.
            This is useful when the database is not used in the main thread or
            when the main process is forked.
            Default to False.

        read_only : bool
            If true, the database is opened read-only. No write transactions are
            allowed on the `Database` object. Multiple read-only `Database`
            objects can be created with the same database path. However, there
            cannot be multiple `Database` objects created with the same
            database path.
            Default to False.

        max_db_size : int
            The maximum size of the database in bytes. Note that this is introduced
            temporarily for now to get around with the default 8TB mmap address
            space limit in some environments. This will be removed once we implement
            a better solution later. The value is default to 1 << 43 (8TB) under 64-bit
            environment and 1GB under 32-bit one.

        """
        if database_path is None:
            database_path = ":memory:"
        if isinstance(database_path, Path):
            database_path = str(database_path)

        self.database_path = database_path
        self.buffer_pool_size = buffer_pool_size
        self.max_num_threads = max_num_threads
        self.compression = compression
        self.read_only = read_only
        self.max_db_size = max_db_size
        self.is_closed = False

        self._database: Any = None  # (type: _kuzu.Database from pybind11)
        if not lazy_init:
            self.init_database()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        self.close()

    @staticmethod
    def get_version() -> str:
        """
        Get the version of the database.

        Returns
        -------
        str
            The version of the database.
        """
        return _kuzu.Database.get_version()  # type: ignore[union-attr]

    @staticmethod
    def get_storage_version() -> int:
        """
        Get the storage version of the database.

        Returns
        -------
        int
            The storage version of the database.
        """
        return _kuzu.Database.get_storage_version()  # type: ignore[union-attr]

    def __getstate__(self) -> dict[str, Any]:
        """
        Return the picklable state of the database object.

        The native database handle is never serialised: it is stored as
        ``None`` and `init_database` creates a new one when it finds no
        handle. Every attribute that `init_database` and
        `check_for_database_close` read is included, so a restored copy can
        be initialised without raising ``AttributeError``.
        """
        return {
            "database_path": self.database_path,
            "buffer_pool_size": self.buffer_pool_size,
            "max_num_threads": self.max_num_threads,
            "compression": self.compression,
            "read_only": self.read_only,
            "max_db_size": self.max_db_size,
            "is_closed": self.is_closed,
            "_database": None,
        }

    def init_database(self) -> None:
        """
        Initialize the database.

        Raises
        ------
        RuntimeError
            If the database has been closed.
        """
        self.check_for_database_close()
        if self._database is None:
            self._database = _kuzu.Database(  # type: ignore[union-attr]
                self.database_path,
                self.buffer_pool_size,
                self.max_num_threads,
                self.compression,
                self.read_only,
                self.max_db_size,
            )

    def get_torch_geometric_remote_backend(
        self, num_threads: int | None = None
    ) -> tuple[KuzuFeatureStore, KuzuGraphStore]:
        """
        Use the database as the remote backend for torch_geometric.

        For the interface of the remote backend, please refer to
        https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html.
        The current implementation is read-only and does not support edge
        features. The IDs of the nodes are based on the internal IDs (i.e., node
        offsets). For the remote node IDs to be consistent with the positions in
        the output tensors, please ensure that no deletion has been performed
        on the node tables.

        The remote backend can also be plugged into the data loader of
        torch_geometric, which is useful for mini-batch training. For example:

        ```python
            loader_kuzu = NeighborLoader(
                data=(feature_store, graph_store),
                num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
                batch_size=LOADER_BATCH_SIZE,
                input_nodes=('paper', input_nodes),
                num_workers=4,
                filter_per_worker=False,
            )
        ```

        Please note that the database instance is not fork-safe, so if more than
        one worker is used, `filter_per_worker` must be set to False.

        Parameters
        ----------
        num_threads : int
            Number of threads to use for data loading. Default to None, which
            means using the number of CPU cores.

        Returns
        -------
        feature_store : KuzuFeatureStore
            Feature store compatible with torch_geometric.
        graph_store : KuzuGraphStore
            Graph store compatible with torch_geometric.
        """
        self.check_for_database_close()
        from .torch_geometric_feature_store import KuzuFeatureStore
        from .torch_geometric_graph_store import KuzuGraphStore

        return (
            KuzuFeatureStore(self, num_threads),
            KuzuGraphStore(self, num_threads),
        )

    def _scan_node_table(
        self,
        table_name: str,
        prop_name: str,
        prop_type: str,
        dim: int,
        indices: IndexType,
        num_threads: int,
    ) -> NDArray[Any]:
        """
        Scan a node table from storage directly, bypassing query engine.
        Used internally by torch_geometric remote backend only.

        Returns a flat array of ``len(indices) * dim`` elements whose dtype
        matches ``prop_type``.

        Raises
        ------
        RuntimeError
            If the database has been closed.
        ValueError
            If ``prop_type`` is not one of the supported numeric types.
        """
        self.check_for_database_close()
        import numpy as np

        self.init_database()

        # Property type -> (numpy dtype of the output buffer, name of the
        # native scan method that fills it).
        scanners = {
            Type.INT64.value: (np.int64, "scan_node_table_as_int64"),
            Type.INT32.value: (np.int32, "scan_node_table_as_int32"),
            Type.INT16.value: (np.int16, "scan_node_table_as_int16"),
            Type.DOUBLE.value: (np.float64, "scan_node_table_as_double"),
            Type.FLOAT.value: (np.float32, "scan_node_table_as_float"),
        }
        scanner = scanners.get(prop_type)
        if scanner is None:
            msg = f"Unsupported property type: {prop_type}"
            raise ValueError(msg)

        dtype, scan_method_name = scanner
        indices_cast = np.array(indices, dtype=np.uint64)
        result = np.empty(len(indices) * dim, dtype=dtype)
        # The native scan method receives ``result`` as its output buffer.
        scan = getattr(self._database, scan_method_name)
        scan(table_name, prop_name, indices_cast, result, num_threads)
        return result

    def close(self) -> None:
        """
        Close the database. Once the database is closed, the lock on the database
        files is released and the database can be opened in another process.

        Note: Call to this method is not required. The Python garbage collector
        will automatically close the database when no references to the database
        object exist. It is recommended not to call this method explicitly. If you
        decide to manually close the database, make sure that all the QueryResult
        and Connection objects are closed before calling this method.
        """
        if self.is_closed:
            return
        self.is_closed = True
        if self._database is not None:
            self._database.close()
            self._database = None  # (type: _kuzu.Database from pybind11)

    def check_for_database_close(self) -> None:
        """
        Check if the database is closed and raise an exception if it is.

        Raises
        ------
        RuntimeError
            If the database is closed.

        """
        if not self.is_closed:
            return
        msg = "Database is closed"
        raise RuntimeError(msg)
Kùzu database instance.
def __init__(
    self,
    database_path: str | Path | None = None,
    *,
    buffer_pool_size: int = 0,
    max_num_threads: int = 0,
    compression: bool = True,
    lazy_init: bool = False,
    read_only: bool = False,
    max_db_size: int = (1 << 43),
):
    """
    Parameters
    ----------
    database_path : str, Path
        The path to database files. If the path is not specified, or empty, or equal to `:memory:`, the database
        will be created in memory.

    buffer_pool_size : int
        The maximum size of buffer pool in bytes. Defaults to ~80% of system memory.

    max_num_threads : int
        The maximum number of threads to use for executing queries.

    compression : bool
        Enable database compression.

    lazy_init : bool
        If True, the database will not be initialized until the first query.
        This is useful when the database is not used in the main thread or
        when the main process is forked.
        Default to False.

    read_only : bool
        If true, the database is opened read-only. No write transactions is
        allowed on the `Database` object. Multiple read-only `Database`
        objects can be created with the same database path. However, there
        cannot be multiple `Database` objects created with the same
        database path.
        Default to False.

    max_db_size : int
        The maximum size of the database in bytes. Note that this is introduced
        temporarily for now to get around with the default 8TB mmap address
        space limit some environment. This will be removed once we implemente
        a better solution later. The value is default to 1 << 43 (8TB) under 64-bit
        environment and 1GB under 32-bit one.

    """
    # A missing path becomes ":memory:"; a Path is stored as its string form.
    resolved_path = ":memory:" if database_path is None else database_path
    if isinstance(resolved_path, Path):
        resolved_path = str(resolved_path)

    self.database_path = resolved_path
    self.buffer_pool_size = buffer_pool_size
    self.max_num_threads = max_num_threads
    self.compression = compression
    self.read_only = read_only
    self.max_db_size = max_db_size
    self.is_closed = False

    self._database: Any = None  # (type: _kuzu.Database from pybind11)
    if lazy_init:
        # No native database is created here.
        return
    self.init_database()
Parameters
- database_path (str, Path): The path to database files. If the path is not specified, or empty, or equal to `:memory:`, the database will be created in memory.
- buffer_pool_size (int): The maximum size of buffer pool in bytes. Defaults to ~80% of system memory.
- max_num_threads (int): The maximum number of threads to use for executing queries.
- compression (bool): Enable database compression.
- lazy_init (bool): If True, the database will not be initialized until the first query. This is useful when the database is not used in the main thread or when the main process is forked. Default to False.
- read_only (bool): If true, the database is opened read-only. No write transactions are allowed on the `Database` object. Multiple read-only `Database` objects can be created with the same database path. However, there cannot be multiple `Database` objects created with the same database path. Default to False.
- max_db_size (int): The maximum size of the database in bytes. Note that this is introduced temporarily for now to get around the default 8TB mmap address space limit in some environments. This will be removed once we implement a better solution later. The value defaults to 1 << 43 (8TB) under a 64-bit environment and 1GB under a 32-bit one.
@staticmethod
def get_version() -> str:
    """
    Report the version of the database.

    Returns
    -------
    str
        The version of the database.
    """
    version = _kuzu.Database.get_version()  # type: ignore[union-attr]
    return version
Get the version of the database.
Returns
- str: The version of the database.
@staticmethod
def get_storage_version() -> int:
    """
    Report the storage version of the database.

    Returns
    -------
    int
        The storage version of the database.
    """
    storage_version = _kuzu.Database.get_storage_version()  # type: ignore[union-attr]
    return storage_version
Get the storage version of the database.
Returns
- int: The storage version of the database.
def init_database(self) -> None:
    """Create the native database handle unless one already exists."""
    self.check_for_database_close()
    if self._database is not None:
        # Already initialised; keep the existing handle.
        return
    self._database = _kuzu.Database(  # type: ignore[union-attr]
        self.database_path,
        self.buffer_pool_size,
        self.max_num_threads,
        self.compression,
        self.read_only,
        self.max_db_size,
    )
Initialize the database.
def get_torch_geometric_remote_backend(
    self, num_threads: int | None = None
) -> tuple[KuzuFeatureStore, KuzuGraphStore]:
    """
    Use the database as the remote backend for torch_geometric.

    For the interface of the remote backend, please refer to
    https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html.
    The current implementation is read-only and does not support edge
    features. The IDs of the nodes are based on the internal IDs (i.e., node
    offsets). For the remote node IDs to be consistent with the positions in
    the output tensors, please ensure that no deletion has been performed
    on the node tables.

    The remote backend can also be plugged into the data loader of
    torch_geometric, which is useful for mini-batch training. For example:

    ```python
        loader_kuzu = NeighborLoader(
            data=(feature_store, graph_store),
            num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
            batch_size=LOADER_BATCH_SIZE,
            input_nodes=('paper', input_nodes),
            num_workers=4,
            filter_per_worker=False,
        )
    ```

    Please note that the database instance is not fork-safe, so if more than
    one worker is used, `filter_per_worker` must be set to False.

    Parameters
    ----------
    num_threads : int
        Number of threads to use for data loading. Default to None, which
        means using the number of CPU cores.

    Returns
    -------
    feature_store : KuzuFeatureStore
        Feature store compatible with torch_geometric.
    graph_store : KuzuGraphStore
        Graph store compatible with torch_geometric.
    """
    self.check_for_database_close()
    # The store modules are imported here, inside the method, rather than at
    # module import time.
    from .torch_geometric_feature_store import KuzuFeatureStore
    from .torch_geometric_graph_store import KuzuGraphStore

    feature_store = KuzuFeatureStore(self, num_threads)
    graph_store = KuzuGraphStore(self, num_threads)
    return (feature_store, graph_store)
Use the database as the remote backend for torch_geometric.
For the interface of the remote backend, please refer to https://pytorch-geometric.readthedocs.io/en/latest/advanced/remote.html. The current implementation is read-only and does not support edge features. The IDs of the nodes are based on the internal IDs (i.e., node offsets). For the remote node IDs to be consistent with the positions in the output tensors, please ensure that no deletion has been performed on the node tables.
The remote backend can also be plugged into the data loader of torch_geometric, which is useful for mini-batch training. For example:
loader_kuzu = NeighborLoader(
data=(feature_store, graph_store),
num_neighbors={('paper', 'cites', 'paper'): [12, 12, 12]},
batch_size=LOADER_BATCH_SIZE,
input_nodes=('paper', input_nodes),
num_workers=4,
filter_per_worker=False,
)
Please note that the database instance is not fork-safe, so if more than one worker is used, `filter_per_worker` must be set to False.
Parameters
- num_threads (int | None): Number of threads to use for data loading. Defaults to None, which means using the number of CPU cores.
Returns
- feature_store (KuzuFeatureStore): Feature store compatible with torch_geometric.
- graph_store (KuzuGraphStore): Graph store compatible with torch_geometric.
def close(self) -> None:
    """
    Close the database. Once the database is closed, the lock on the database
    files is released and the database can be opened in another process.

    Calling this method on an already closed database does nothing.

    Note: Call to this method is not required. The Python garbage collector
    will automatically close the database when no references to the database
    object exist. It is recommended not to call this method explicitly. If you
    decide to manually close the database, make sure that all the QueryResult
    and Connection objects are closed before calling this method.
    """
    if not self.is_closed:
        # Flip the flag before touching the native object.
        self.is_closed = True
        native_db = self._database
        if native_db is not None:
            native_db.close()
            self._database: Any = None  # (type: _kuzu.Database from pybind11)
Close the database. Once the database is closed, the lock on the database files is released and the database can be opened in another process.
Note: Call to this method is not required. The Python garbage collector will automatically close the database when no references to the database object exist. It is recommended not to call this method explicitly. If you decide to manually close the database, make sure that all the QueryResult and Connection objects are closed before calling this method.
def check_for_database_close(self) -> None:
    """
    Ensure the database is still open.

    Raises
    ------
    RuntimeError
        If the database is closed.

    """
    if self.is_closed:
        msg = "Database is closed"
        raise RuntimeError(msg)
Check if the database is closed and raise an exception if it is.
Raises
- RuntimeError: If the database is closed.
class PreparedStatement:
    """
    A prepared statement is a parameterized query which can avoid planning the
    same query for repeated execution.
    """

    def __init__(self, connection: Connection, query: str):
        """
        Prepare `query` on the given connection.

        Parameters
        ----------
        connection : Connection
            Connection to a database.
        query : str
            Query to prepare.
        """
        native_connection = connection._connection
        self._prepared_statement = native_connection.prepare(query)

    def is_success(self) -> bool:
        """
        Check if the prepared statement is successfully prepared.

        Returns
        -------
        bool
            True if the prepared statement is successfully prepared.
        """
        prepared = self._prepared_statement
        return prepared.is_success()

    def get_error_message(self) -> str:
        """
        Get the error message if the query is not prepared successfully.

        Returns
        -------
        str
            Error message.
        """
        prepared = self._prepared_statement
        return prepared.get_error_message()
A prepared statement is a parameterized query which can avoid planning the same query for repeated execution.
16 def __init__(self, connection: Connection, query: str): 17 """ 18 Parameters 19 ---------- 20 connection : Connection 21 Connection to a database. 22 query : str 23 Query to prepare. 24 """ 25 self._prepared_statement = connection._connection.prepare(query)
Parameters
- connection (Connection): Connection to a database.
- query (str): Query to prepare.
def is_success(self) -> bool:
    """
    Check if the prepared statement is successfully prepared.

    Returns
    -------
    bool
        True if the prepared statement is successfully prepared.
    """
    prepared = self._prepared_statement
    return prepared.is_success()
Check if the prepared statement is successfully prepared.
Returns
- bool: True if the prepared statement is successfully prepared.
def get_error_message(self) -> str:
    """
    Get the error message if the query is not prepared successfully.

    Returns
    -------
    str
        Error message.
    """
    prepared = self._prepared_statement
    return prepared.get_error_message()
Get the error message if the query is not prepared successfully.
Returns
- str: Error message.
class QueryResult:
    """QueryResult stores the result of a query execution."""

    def __init__(self, connection: _kuzu.Connection, query_result: _kuzu.QueryResult):  # type: ignore[name-defined]
        """
        Parameters
        ----------
        connection : _kuzu.Connection
            The underlying C++ connection object from pybind11.

        query_result : _kuzu.QueryResult
            The underlying C++ query result object from pybind11.

        """
        self.connection = connection
        self._query_result = query_result
        self.is_closed = False

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_value: BaseException | None,
        exc_traceback: TracebackType | None,
    ) -> None:
        # Leaving the `with` block always closes the result.
        self.close()

    def __del__(self) -> None:
        self.close()

    def check_for_query_result_close(self) -> None:
        """
        Check if the query result is closed and raise an exception if it is.

        Raises
        ------
        RuntimeError
            If the query result is closed.

        """
        if self.is_closed:
            msg = "Query result is closed"
            raise RuntimeError(msg)

    def has_next(self) -> bool:
        """
        Check if there are more rows in the query result.

        Returns
        -------
        bool
            True if there are more rows in the query result, False otherwise.

        """
        self.check_for_query_result_close()
        return self._query_result.hasNext()

    def get_next(self) -> list[Any]:
        """
        Get the next row in the query result.

        Returns
        -------
        list
            Next row in the query result.

        """
        self.check_for_query_result_close()
        return self._query_result.getNext()

    def close(self) -> None:
        """Close the query result. Closing an already closed result does nothing."""
        if self.is_closed:
            return
        self._query_result.close()
        # Dropping the reference allows the connection to be garbage
        # collected if the query result is closed manually by the user.
        self.connection = None
        self.is_closed = True

    def get_as_df(self) -> pd.DataFrame:
        """
        Get the query result as a Pandas DataFrame.

        See Also
        --------
        get_as_pl : Get the query result as a Polars DataFrame.
        get_as_arrow : Get the query result as a PyArrow Table.

        Returns
        -------
        pandas.DataFrame
            Query result as a Pandas DataFrame.

        """
        self.check_for_query_result_close()

        return self._query_result.getAsDF()

    def get_as_pl(self) -> pl.DataFrame:
        """
        Get the query result as a Polars DataFrame.

        See Also
        --------
        get_as_df : Get the query result as a Pandas DataFrame.
        get_as_arrow : Get the query result as a PyArrow Table.

        Returns
        -------
        polars.DataFrame
            Query result as a Polars DataFrame.
        """
        import polars as pl

        self.check_for_query_result_close()

        # note: polars should always export just a single chunk,
        # (eg: "-1") otherwise it will just need to rechunk anyway
        return pl.from_arrow(  # type: ignore[return-value]
            data=self.get_as_arrow(chunk_size=-1),
        )

    def get_as_arrow(self, chunk_size: int | None = None) -> pa.Table:
        """
        Get the query result as a PyArrow Table.

        Parameters
        ----------
        chunk_size : int or None
            Number of rows to include in each chunk.

            None
                The chunk size is adaptive and depends on the number of columns in the query result.
            -1 or 0
                The entire result is returned as a single chunk.
            > 0
                The chunk size is the number of rows specified.

        See Also
        --------
        get_as_pl : Get the query result as a Polars DataFrame.
        get_as_df : Get the query result as a Pandas DataFrame.

        Returns
        -------
        pyarrow.Table
            Query result as a PyArrow Table.
        """
        self.check_for_query_result_close()

        if chunk_size is None:
            # Adaptive; target 10m total elements in each chunk.
            # (eg: if we had 10 cols, this would result in a 1m row chunk_size).
            target_n_elems = 10_000_000
            # Treat a result without columns as one column so the adaptive
            # sizing can never divide by zero.
            n_columns = max(len(self.get_column_names()), 1)
            chunk_size = max(target_n_elems // n_columns, 10)
        elif chunk_size <= 0:
            # No chunking: return the entire result as a single chunk
            chunk_size = self.get_num_tuples()

        return self._query_result.getAsArrow(chunk_size)

    def get_column_data_types(self) -> list[str]:
        """
        Get the data types of the columns in the query result.

        Returns
        -------
        list
            Data types of the columns in the query result.

        """
        self.check_for_query_result_close()
        return self._query_result.getColumnDataTypes()

    def get_column_names(self) -> list[str]:
        """
        Get the names of the columns in the query result.

        Returns
        -------
        list
            Names of the columns in the query result.

        """
        self.check_for_query_result_close()
        return self._query_result.getColumnNames()

    def get_schema(self) -> dict[str, str]:
        """
        Get the column schema of the query result.

        Returns
        -------
        dict
            Schema of the query result, mapping column name to data type.

        """
        self.check_for_query_result_close()
        return dict(
            zip(
                self._query_result.getColumnNames(),
                self._query_result.getColumnDataTypes(),
            )
        )

    def reset_iterator(self) -> None:
        """Reset the iterator of the query result."""
        self.check_for_query_result_close()
        self._query_result.resetIterator()

    def get_as_networkx(
        self,
        directed: bool = True,  # noqa: FBT001
    ) -> nx.MultiGraph | nx.MultiDiGraph:
        """
        Convert the nodes and rels in query result into a NetworkX directed or undirected graph
        with the following rules:

        - Columns with data type other than node or rel will be ignored.
        - Duplicated nodes and rels will be converted only once.

        The iterator of the query result is reset before conversion. Graph node
        IDs have the form ``"<label>_<primary key value>"``.

        Parameters
        ----------
        directed : bool
            Whether the graph should be directed. Defaults to True.

        Returns
        -------
        networkx.MultiDiGraph or networkx.MultiGraph
            Query result as a NetworkX graph.

        Raises
        ------
        RuntimeError
            If the query result is closed.
        KeyError
            If a rel refers to a source or destination node that is not part
            of the query result.

        """
        self.check_for_query_result_close()
        import networkx as nx

        nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph()
        properties_to_extract = self._get_properties_to_extract()

        self.reset_iterator()

        nodes = {}
        rels = {}
        table_primary_key_dict = {}

        def internal_key(internal_id: dict[str, Any]) -> tuple[int, int]:
            return internal_id["table"], internal_id["offset"]

        def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
            node_label = node["_label"]
            return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"

        # De-duplicate nodes and rels by keying them on their internal IDs
        while self.has_next():
            row = self.get_next()
            for i, (column_type, _) in properties_to_extract.items():
                value = row[i]
                # Skip empty nodes and rels, which may be returned by
                # OPTIONAL MATCH
                if value is None or value == {}:
                    continue
                if column_type == Type.NODE.value:
                    nodes[internal_key(value["_id"])] = value

                elif column_type == Type.REL.value:
                    rels[internal_key(value["_id"])] = value

                elif column_type == Type.RECURSIVE_REL.value:
                    for node in value["_nodes"]:
                        nodes[internal_key(node["_id"])] = node
                    for rel in value["_rels"]:
                        # Drop properties that have no value for this rel.
                        for key in list(rel.keys()):
                            if rel[key] is None:
                                del rel[key]
                        rels[internal_key(rel["_id"])] = rel

        # Add nodes
        for node in nodes.values():
            label = node["_label"]
            if label not in table_primary_key_dict:
                # Look up the primary key property once per node table.
                props = self.connection._get_node_property_names(label)
                for prop_name in props:
                    if props[prop_name]["is_primary_key"]:
                        table_primary_key_dict[label] = prop_name
                        break
            node_id = encode_node_id(node, table_primary_key_dict)
            # The label is also stored as a boolean attribute on the node.
            node[label] = True
            nx_graph.add_node(node_id, **node)

        # Add rels
        for rel in rels.values():
            src_node = nodes[internal_key(rel["_src"])]
            dst_node = nodes[internal_key(rel["_dst"])]
            src_id = encode_node_id(src_node, table_primary_key_dict)
            dst_id = encode_node_id(dst_node, table_primary_key_dict)
            nx_graph.add_edge(src_id, dst_id, **rel)
        return nx_graph

    def _get_properties_to_extract(self) -> dict[int, tuple[str, str]]:
        """Map each node/rel/recursive-rel column index to its (type, name) pair."""
        column_names = self.get_column_names()
        column_types = self.get_column_data_types()
        graph_types = (
            Type.NODE.value,
            Type.REL.value,
            Type.RECURSIVE_REL.value,
        )

        # Keep only node and rel columns; all other columns are ignored
        return {
            i: (column_type, column_name)
            for i, (column_name, column_type) in enumerate(zip(column_names, column_types))
            if column_type in graph_types
        }

    def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]:  # type: ignore[type-arg]
        """
        Converts the nodes and rels in query result into a PyTorch Geometric graph representation
        torch_geometric.data.Data or torch_geometric.data.HeteroData.

        For node conversion, numerical and boolean properties are directly converted into tensor and
        stored in Data/HeteroData. Properties that cannot be converted into tensor automatically
        (please refer to the notes below for more detail) are returned as unconverted_properties.

        For rel conversion, rel is converted into edge_index tensor directly. Edge properties are returned
        as edge_properties.

        Node properties that cannot be converted into tensor automatically:
        - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted
          automatically.
        - If a node property contains a null value, it cannot be converted automatically.
        - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be
          converted automatically.
        - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length
          is 6 for one node but 5 for another node), it cannot be converted automatically.

        Additional conversion rules:
        - Columns with data type other than node or rel will be ignored.
        - Duplicated nodes and rels will be converted only once.

        Returns
        -------
        torch_geometric.data.Data or torch_geometric.data.HeteroData
            Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties
            and edge_index tensor.

        dict
            A dictionary that maps the positional offset of each node in Data/HeteroData to its primary
            key in the database.

        dict
            A dictionary containing node properties that cannot be converted into tensor automatically. The
            order of values for each property is aligned with nodes in Data/HeteroData.

        dict
            A dictionary containing edge properties. The order of values for each property is aligned with
            edge_index in Data/HeteroData.
        """
        self.check_for_query_result_close()
        # The conversion itself is delegated to TorchGeometricResultConverter;
        # nothing from torch or torch_geometric is imported in this method.
        converter = TorchGeometricResultConverter(self)
        return converter.get_as_torch_geometric()

    def get_execution_time(self) -> int:
        """
        Get the time in ms which was required for executing the query.

        Returns
        -------
        int
            Query execution time in ms, as reported by the underlying query
            result object.
            NOTE(review): earlier documentation described this value as a
            double; confirm the actual return type of the binding.

        """
        self.check_for_query_result_close()
        return self._query_result.getExecutionTime()

    def get_compiling_time(self) -> int:
        """
        Get the time in ms which was required for compiling the query.

        Returns
        -------
        int
            Query compile time in ms, as reported by the underlying query
            result object.
            NOTE(review): earlier documentation described this value as a
            double; confirm the actual return type of the binding.

        """
        self.check_for_query_result_close()
        return self._query_result.getCompilingTime()

    def get_num_tuples(self) -> int:
        """
        Get the number of tuples which the query returned.

        Returns
        -------
        int
            Number of tuples.

        """
        self.check_for_query_result_close()
        return self._query_result.getNumTuples()
QueryResult stores the result of a query execution.
30 def __init__(self, connection: _kuzu.Connection, query_result: _kuzu.QueryResult): # type: ignore[name-defined] 31 """ 32 Parameters 33 ---------- 34 connection : _kuzu.Connection 35 The underlying C++ connection object from pybind11. 36 37 query_result : _kuzu.QueryResult 38 The underlying C++ query result object from pybind11. 39 40 """ 41 self.connection = connection 42 self._query_result = query_result 43 self.is_closed = False
Parameters
- connection (_kuzu.Connection): The underlying C++ connection object from pybind11.
- query_result (_kuzu.QueryResult): The underlying C++ query result object from pybind11.
def check_for_query_result_close(self) -> None:
    """
    Ensure the query result is still open.

    Raises
    ------
    RuntimeError
        If the query result is closed.

    """
    if not self.is_closed:
        return
    msg = "Query result is closed"
    raise RuntimeError(msg)
Check if the query result is closed and raise an exception if it is.
Raises
- RuntimeError: If the query result is closed.
def has_next(self) -> bool:
    """
    Check if there are more rows in the query result.

    Returns
    -------
    bool
        True if there are more rows in the query result, False otherwise.

    """
    self.check_for_query_result_close()
    native_result = self._query_result
    return native_result.hasNext()
Check if there are more rows in the query result.
Returns
- bool: True if there are more rows in the query result, False otherwise.
def get_next(self) -> list[Any]:
    """
    Get the next row in the query result.

    Returns
    -------
    list
        Next row in the query result.

    """
    self.check_for_query_result_close()
    native_result = self._query_result
    return native_result.getNext()
Get the next row in the query result.
Returns
- list: Next row in the query result.
def close(self) -> None:
    """Close the query result. Closing an already closed result does nothing."""
    if self.is_closed:
        return
    self._query_result.close()
    # Dropping the reference allows the connection to be garbage collected
    # if the query result is closed manually by the user.
    self.connection = None
    self.is_closed = True
Close the query result.
def get_as_df(self) -> pd.DataFrame:
    """
    Get the query result as a Pandas DataFrame.

    See Also
    --------
    get_as_pl : Get the query result as a Polars DataFrame.
    get_as_arrow : Get the query result as a PyArrow Table.

    Returns
    -------
    pandas.DataFrame
        Query result as a Pandas DataFrame.

    """
    self.check_for_query_result_close()
    native_result = self._query_result
    return native_result.getAsDF()
Get the query result as a Pandas DataFrame.
See Also
get_as_pl
: Get the query result as a Polars DataFrame.
get_as_arrow
: Get the query result as a PyArrow Table.
Returns
- pandas.DataFrame: Query result as a Pandas DataFrame.
def get_as_pl(self) -> pl.DataFrame:
    """
    Get the query result as a Polars DataFrame.

    The result is first exported as a single-chunk PyArrow table and then
    handed to ``polars.from_arrow``.

    See Also
    --------
    get_as_df : Get the query result as a Pandas DataFrame.
    get_as_arrow : Get the query result as a PyArrow Table.

    Returns
    -------
    polars.DataFrame
        Query result as a Polars DataFrame.
    """
    import polars as pl

    self.check_for_query_result_close()

    # note: polars should always export just a single chunk,
    # (eg: "-1") otherwise it will just need to rechunk anyway
    arrow_table = self.get_as_arrow(chunk_size=-1)
    return pl.from_arrow(data=arrow_table)  # type: ignore[return-value]
Get the query result as a Polars DataFrame.
See Also
get_as_df
: Get the query result as a Pandas DataFrame.
get_as_arrow
: Get the query result as a PyArrow Table.
Returns
- polars.DataFrame: Query result as a Polars DataFrame.
def get_as_arrow(self, chunk_size: int | None = None) -> pa.Table:
    """
    Get the query result as a PyArrow Table.

    Parameters
    ----------
    chunk_size : int or None
        Number of rows to include in each chunk.

        None
            The chunk size is adaptive and depends on the number of columns in the query result.
        -1 or 0
            The entire result is returned as a single chunk.
        > 0
            The chunk size is the number of rows specified.

    See Also
    --------
    get_as_pl : Get the query result as a Polars DataFrame.
    get_as_df : Get the query result as a Pandas DataFrame.

    Returns
    -------
    pyarrow.Table
        Query result as a PyArrow Table.

    Raises
    ------
    RuntimeError
        If the query result is closed.
    """
    self.check_for_query_result_close()

    if chunk_size is None:
        # Adaptive; target 10m total elements in each chunk.
        # (eg: if we had 10 cols, this would result in a 1m row chunk_size).
        target_n_elems = 10_000_000
        # Treat a result without columns as one column so the adaptive
        # sizing can never divide by zero.
        n_columns = max(len(self.get_column_names()), 1)
        chunk_size = max(target_n_elems // n_columns, 10)
    elif chunk_size <= 0:
        # No chunking: return the entire result as a single chunk
        chunk_size = self.get_num_tuples()

    return self._query_result.getAsArrow(chunk_size)
Get the query result as a PyArrow Table.
Parameters
- chunk_size (int | None): Number of rows to include in each chunk. None: the chunk size is adaptive and depends on the number of columns in the query result. -1 or 0: the entire result is returned as a single chunk. > 0: the chunk size is the number of rows specified.
See Also
get_as_pl
: Get the query result as a Polars DataFrame.
get_as_df
: Get the query result as a Pandas DataFrame.
Returns
- pyarrow.Table: Query result as a PyArrow Table.
def get_column_data_types(self) -> list[str]:
    """
    Get the data types of the columns in the query result.

    Returns
    -------
    list
        Data types of the columns in the query result.

    """
    self.check_for_query_result_close()
    native_result = self._query_result
    return native_result.getColumnDataTypes()
Get the data types of the columns in the query result.
Returns
- list: Data types of the columns in the query result.
def get_column_names(self) -> list[str]:
    """
    Get the names of the columns in the query result.

    Returns
    -------
    list
        Names of the columns in the query result.

    """
    self.check_for_query_result_close()
    native_result = self._query_result
    return native_result.getColumnNames()
Get the names of the columns in the query result.
Returns
- list: Names of the columns in the query result.
def get_schema(self) -> dict[str, str]:
    """
    Get the column schema of the query result.

    Returns
    -------
    dict
        Schema of the query result, mapping column name to data type.

    """
    self.check_for_query_result_close()
    names = self._query_result.getColumnNames()
    data_types = self._query_result.getColumnDataTypes()
    return {name: data_type for name, data_type in zip(names, data_types)}
Get the column schema of the query result.
Returns
- dict: Schema of the query result.
def reset_iterator(self) -> None:
    """Reset the iterator of the query result."""
    self.check_for_query_result_close()
    native_result = self._query_result
    native_result.resetIterator()
Reset the iterator of the query result.
def get_as_networkx(
    self,
    directed: bool = True,  # noqa: FBT001
) -> nx.MultiGraph | nx.MultiDiGraph:
    """
    Convert the nodes and rels in query result into a NetworkX directed or undirected graph
    with the following rules:

    - Columns with data type other than node or rel will be ignored.
    - Duplicated nodes and rels will be converted only once.

    The iterator of the query result is reset before conversion. Graph node
    IDs have the form ``"<label>_<primary key value>"``.

    Parameters
    ----------
    directed : bool
        Whether the graph should be directed. Defaults to True.

    Returns
    -------
    networkx.MultiDiGraph or networkx.MultiGraph
        Query result as a NetworkX graph.

    Raises
    ------
    RuntimeError
        If the query result is closed.
    KeyError
        If a rel refers to a source or destination node that is not part
        of the query result.

    """
    self.check_for_query_result_close()
    import networkx as nx

    nx_graph = nx.MultiDiGraph() if directed else nx.MultiGraph()
    properties_to_extract = self._get_properties_to_extract()

    self.reset_iterator()

    nodes = {}
    rels = {}
    table_primary_key_dict = {}

    def internal_key(internal_id: dict[str, Any]) -> tuple[int, int]:
        return internal_id["table"], internal_id["offset"]

    def encode_node_id(node: dict[str, Any], table_primary_key_dict: dict[str, Any]) -> str:
        node_label = node["_label"]
        return f"{node_label}_{node[table_primary_key_dict[node_label]]!s}"

    # De-duplicate nodes and rels by keying them on their internal IDs
    while self.has_next():
        row = self.get_next()
        for i, (column_type, _) in properties_to_extract.items():
            value = row[i]
            # Skip empty nodes and rels, which may be returned by
            # OPTIONAL MATCH
            if value is None or value == {}:
                continue
            if column_type == Type.NODE.value:
                nodes[internal_key(value["_id"])] = value

            elif column_type == Type.REL.value:
                rels[internal_key(value["_id"])] = value

            elif column_type == Type.RECURSIVE_REL.value:
                for node in value["_nodes"]:
                    nodes[internal_key(node["_id"])] = node
                for rel in value["_rels"]:
                    # Drop properties that have no value for this rel.
                    for key in list(rel.keys()):
                        if rel[key] is None:
                            del rel[key]
                    rels[internal_key(rel["_id"])] = rel

    # Add nodes
    for node in nodes.values():
        label = node["_label"]
        if label not in table_primary_key_dict:
            # Look up the primary key property once per node table.
            props = self.connection._get_node_property_names(label)
            for prop_name in props:
                if props[prop_name]["is_primary_key"]:
                    table_primary_key_dict[label] = prop_name
                    break
        node_id = encode_node_id(node, table_primary_key_dict)
        # The label is also stored as a boolean attribute on the node.
        node[label] = True
        nx_graph.add_node(node_id, **node)

    # Add rels
    for rel in rels.values():
        src_node = nodes[internal_key(rel["_src"])]
        dst_node = nodes[internal_key(rel["_dst"])]
        src_id = encode_node_id(src_node, table_primary_key_dict)
        dst_id = encode_node_id(dst_node, table_primary_key_dict)
        nx_graph.add_edge(src_id, dst_id, **rel)
    return nx_graph
Convert the nodes and rels in query result into a NetworkX directed or undirected graph with the following rules: Columns with data type other than node or rel will be ignored. Duplicated nodes and rels will be converted only once.
Parameters
- directed (bool): Whether the graph should be directed. Defaults to True.
Returns
- networkx.MultiDiGraph or networkx.MultiGraph: Query result as a NetworkX graph.
def get_as_torch_geometric(self) -> tuple[geo.Data | geo.HeteroData, dict, dict, dict]:  # type: ignore[type-arg]
    """
    Converts the nodes and rels in query result into a PyTorch Geometric graph representation
    torch_geometric.data.Data or torch_geometric.data.HeteroData.

    For node conversion, numerical and boolean properties are directly converted into tensor and
    stored in Data/HeteroData. Properties that cannot be converted into tensor automatically
    (please refer to the notes below for more detail) are returned as unconverted_properties.

    For rel conversion, rel is converted into edge_index tensor directly. Edge properties are returned
    as edge_properties.

    Node properties that cannot be converted into tensor automatically:
    - If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted
      automatically.
    - If a node property contains a null value, it cannot be converted automatically.
    - If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be
      converted automatically.
    - If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length
      is 6 for one node but 5 for another node), it cannot be converted automatically.

    Additional conversion rules:
    - Columns with data type other than node or rel will be ignored.
    - Duplicated nodes and rels will be converted only once.

    Returns
    -------
    torch_geometric.data.Data or torch_geometric.data.HeteroData
        Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties
        and edge_index tensor.

    dict
        A dictionary that maps the positional offset of each node in Data/HeteroData to its primary
        key in the database.

    dict
        A dictionary containing node properties that cannot be converted into tensor automatically. The
        order of values for each property is aligned with nodes in Data/HeteroData.

    dict
        A dictionary containing edge properties. The order of values for each property is aligned with
        edge_index in Data/HeteroData.
    """
    self.check_for_query_result_close()
    # The conversion itself is delegated to TorchGeometricResultConverter;
    # nothing from torch or torch_geometric is imported in this method.
    return TorchGeometricResultConverter(self).get_as_torch_geometric()
Converts the nodes and rels in query result into a PyTorch Geometric graph representation torch_geometric.data.Data or torch_geometric.data.HeteroData.
For node conversion, numerical and boolean properties are directly converted into tensor and stored in Data/HeteroData. Properties that cannot be converted into tensor automatically (please refer to the notes below for more detail) are returned as unconverted_properties.
For rel conversion, rel is converted into edge_index tensor directly. Edge properties are returned as edge_properties.
Node properties that cannot be converted into tensor automatically:
- If the type of a node property is not one of INT64, DOUBLE, or BOOL, it cannot be converted automatically.
- If a node property contains a null value, it cannot be converted automatically.
- If a node property contains a nested list of variable length (e.g. [[1,2],[3]]), it cannot be converted automatically.
- If a node property is a list or nested list, but the shape is inconsistent (e.g. the list length is 6 for one node but 5 for another node), it cannot be converted automatically.
Additional conversion rules:
- Columns with data type other than node or rel will be ignored.
- Duplicated nodes and rels will be converted only once.
Returns
- torch_geometric.data.Data or torch_geometric.data.HeteroData: Query result as a PyTorch Geometric graph. Containing numeric or boolean node properties and edge_index tensor.
- dict: A dictionary that maps the positional offset of each node in Data/HeteroData to its primary key in the database.
- dict: A dictionary containing node properties that cannot be converted into tensor automatically. The order of values for each property is aligned with nodes in Data/HeteroData.
- dict: A dictionary contains edge properties. The order of values for each property is aligned with edge_index in Data/HeteroData.
def get_execution_time(self) -> float:
    """
    Get the time in ms which was required for executing the query.

    Returns
    -------
    float
        Query execution time in ms (a double on the native side).

    """
    # Run the closed-result check first; the helper is defined elsewhere in
    # the class (presumably it raises when the result is already closed).
    self.check_for_query_result_close()
    return self._query_result.getExecutionTime()
Get the time in ms which was required for executing the query.
Returns
- double: Query execution time as double in ms.
def get_compiling_time(self) -> float:
    """
    Get the time in ms which was required for compiling the query.

    Returns
    -------
    float
        Query compile time in ms (a double on the native side).

    """
    # Run the closed-result check first; the helper is defined elsewhere in
    # the class (presumably it raises when the result is already closed).
    self.check_for_query_result_close()
    return self._query_result.getCompilingTime()
Get the time in ms which was required for compiling the query.
Returns
- double: Query compile time as double in ms.
def get_num_tuples(self) -> int:
    """
    Report how many tuples the query produced.

    Returns
    -------
    int
        Tuple count of this query result.

    """
    # The closed-result check always runs before the underlying result is read.
    self.check_for_query_result_close()
    tuple_count = self._query_result.getNumTuples()
    return tuple_count
Get the number of tuples which the query returned.
Returns
- int: Number of tuples.
class Type(Enum):
    """
    The type of a value in the database.

    Each member's value is the type name as an upper-case string.

    Notes
    -----
    ``TIMSTAMP_TZ`` is a historical misspelling of ``TIMESTAMP_TZ``. It is
    kept as an enum alias, so ``Type.TIMSTAMP_TZ is Type.TIMESTAMP_TZ`` and
    existing callers keep working; new code should use ``TIMESTAMP_TZ``.
    """

    ANY = "ANY"
    NODE = "NODE"
    REL = "REL"
    RECURSIVE_REL = "RECURSIVE_REL"
    SERIAL = "SERIAL"
    BOOL = "BOOL"
    INT64 = "INT64"
    INT32 = "INT32"
    INT16 = "INT16"
    INT8 = "INT8"
    UINT64 = "UINT64"
    UINT32 = "UINT32"
    UINT16 = "UINT16"
    UINT8 = "UINT8"
    INT128 = "INT128"
    DOUBLE = "DOUBLE"
    FLOAT = "FLOAT"
    DATE = "DATE"
    TIMESTAMP = "TIMESTAMP"
    TIMESTAMP_TZ = "TIMESTAMP_TZ"
    # Deprecated misspelling: same value as TIMESTAMP_TZ, so Enum treats it as
    # an alias of that member rather than as a separate member.
    TIMSTAMP_TZ = "TIMESTAMP_TZ"
    TIMESTAMP_NS = "TIMESTAMP_NS"
    TIMESTAMP_MS = "TIMESTAMP_MS"
    TIMESTAMP_SEC = "TIMESTAMP_SEC"
    INTERVAL = "INTERVAL"
    INTERNAL_ID = "INTERNAL_ID"
    STRING = "STRING"
    BLOB = "BLOB"
    UUID = "UUID"
    LIST = "LIST"
    ARRAY = "ARRAY"
    STRUCT = "STRUCT"
    MAP = "MAP"
    UNION = "UNION"
The type of a value in the database.