Source code for airflow.providers.ydb.hooks._vendor.dbapi.cursor

import collections.abc
import dataclasses
import functools
import hashlib
import itertools
import logging
import posixpath
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union

import ydb
import ydb.aio
from sqlalchemy import util

from .errors import (
    DatabaseError,
    DataError,
    IntegrityError,
    InternalError,
    NotSupportedError,
    OperationalError,
    ProgrammingError,
)

[docs]logger = logging.getLogger(__name__)
[docs]def get_column_type(type_obj: Any) -> str: return str(ydb.convert.type_to_native(type_obj))
@dataclasses.dataclass
[docs]class YdbQuery:
[docs] yql_text: str
[docs] parameters_types: Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]] = dataclasses.field( default_factory=dict )
[docs] is_ddl: bool = False
def _handle_ydb_errors(func): @functools.wraps(func) def wrapper(*args, **kwargs): try: return func(*args, **kwargs) except (ydb.issues.AlreadyExists, ydb.issues.PreconditionFailed) as e: raise IntegrityError(e.message, original_error=e) from e except (ydb.issues.Unsupported, ydb.issues.Unimplemented) as e: raise NotSupportedError(e.message, original_error=e) from e except (ydb.issues.BadRequest, ydb.issues.SchemeError) as e: raise ProgrammingError(e.message, original_error=e) from e except ( ydb.issues.TruncatedResponseError, ydb.issues.ConnectionError, ydb.issues.Aborted, ydb.issues.Unavailable, ydb.issues.Overloaded, ydb.issues.Undetermined, ydb.issues.Timeout, ydb.issues.Cancelled, ydb.issues.SessionBusy, ydb.issues.SessionExpired, ydb.issues.SessionPoolEmpty, ) as e: raise OperationalError(e.message, original_error=e) from e except ydb.issues.GenericError as e: raise DataError(e.message, original_error=e) from e except ydb.issues.InternalError as e: raise InternalError(e.message, original_error=e) from e except ydb.Error as e: raise DatabaseError(e.message, original_error=e) from e except Exception as e: raise DatabaseError("Failed to execute query") from e return wrapper
[docs]class Cursor: def __init__( self, session_pool: Union[ydb.SessionPool, ydb.aio.SessionPool], tx_mode: ydb.AbstractTransactionModeBuilder, tx_context: Optional[ydb.BaseTxContext] = None, table_path_prefix: str = "", ): self.session_pool = session_pool self.tx_mode = tx_mode self.tx_context = tx_context self.description = None self.arraysize = 1 self.rows = None self._rows_prefetched = None self.root_directory = table_path_prefix @_handle_ydb_errors
[docs] def describe_table(self, abs_table_path: str) -> ydb.TableDescription: return self._retry_operation_in_pool(self._describe_table, abs_table_path)
[docs] def check_exists(self, abs_table_path: str) -> bool: try: self._retry_operation_in_pool(self._describe_path, abs_table_path) return True except ydb.SchemeError: return False
@_handle_ydb_errors
[docs] def get_table_names(self, abs_dir_path: str) -> List[str]: directory: ydb.Directory = self._retry_operation_in_pool(self._list_directory, abs_dir_path) result = [] for child in directory.children: child_abs_path = posixpath.join(abs_dir_path, child.name) if child.is_table(): result.append(child_abs_path) elif child.is_directory() and not child.name.startswith("."): result.extend(self.get_table_names(child_abs_path)) return result
[docs] def execute(self, operation: YdbQuery, parameters: Optional[Mapping[str, Any]] = None): query = self._get_ydb_query(operation) logger.info("execute sql: %s, params: %s", query, parameters) if operation.is_ddl: chunks = self._execute_ddl(query) else: chunks = self._execute_dml(query, parameters) rows = self._rows_iterable(chunks) # Prefetch the description: try: first_row = next(rows) except StopIteration: pass else: rows = itertools.chain((first_row,), rows) if self.rows is not None: rows = itertools.chain(self.rows, rows) self.rows = rows
def _get_ydb_query(self, operation: YdbQuery) -> Union[ydb.DataQuery, str]: pragma = "" if self.root_directory: pragma = f'PRAGMA TablePathPrefix = "{self.root_directory}";\n' yql_with_pragma = pragma + operation.yql_text if operation.is_ddl or not operation.parameters_types: return yql_with_pragma return self._make_data_query(yql_with_pragma, operation.parameters_types) def _make_data_query( self, yql_text: str, parameters_types: Dict[str, Union[ydb.PrimitiveType, ydb.AbstractTypeBuilder]], ) -> ydb.DataQuery: """ ydb.DataQuery uses hashed SQL text as cache key, which may cause issues if parameters change type within the same session, so we include parameter types to the key to prevent false positive cache hit. """ sorted_parameters = sorted(parameters_types.items()) # dict keys are unique, so the sorting is stable yql_with_params = yql_text + "".join([k + str(v) for k, v in sorted_parameters]) name = hashlib.sha256(yql_with_params.encode("utf-8")).hexdigest() return ydb.DataQuery(yql_text, parameters_types, name=name) @_handle_ydb_errors def _execute_dml( self, query: Union[ydb.DataQuery, str], parameters: Optional[Mapping[str, Any]] = None ) -> ydb.convert.ResultSets: prepared_query = query if isinstance(query, str) and parameters: if self.tx_context: prepared_query = self._run_operation_in_session(self._prepare, query) else: prepared_query = self._retry_operation_in_pool(self._prepare, query) if self.tx_context: return self._run_operation_in_tx(self._execute_in_tx, prepared_query, parameters) return self._retry_operation_in_pool(self._execute_in_session, self.tx_mode, prepared_query, parameters) @_handle_ydb_errors def _execute_ddl(self, query: str) -> ydb.convert.ResultSets: return self._retry_operation_in_pool(self._execute_scheme, query) @staticmethod def _execute_scheme(session: ydb.Session, query: str) -> ydb.convert.ResultSets: return session.execute_scheme(query) @staticmethod def _describe_table(session: ydb.Session, abs_table_path: str) -> ydb.TableDescription: return session.describe_table(abs_table_path) @staticmethod def _describe_path(session: ydb.Session, table_path: str) -> ydb.SchemeEntry: return session._driver.scheme_client.describe_path(table_path) @staticmethod def _list_directory(session: ydb.Session, abs_dir_path: str) -> ydb.Directory: return session._driver.scheme_client.list_directory(abs_dir_path) @staticmethod def _prepare(session: ydb.Session, query: str) -> ydb.DataQuery: return session.prepare(query) @staticmethod def _execute_in_tx( tx_context: ydb.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] ) -> ydb.convert.ResultSets: return tx_context.execute(prepared_query, parameters, commit_tx=False) @staticmethod def _execute_in_session( session: ydb.Session, tx_mode: ydb.AbstractTransactionModeBuilder, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]], ) -> ydb.convert.ResultSets: return session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) def _run_operation_in_tx(self, callee: collections.abc.Callable, *args, **kwargs): return callee(self.tx_context, *args, **kwargs) def _run_operation_in_session(self, callee: collections.abc.Callable, *args, **kwargs): return callee(self.tx_context.session, *args, **kwargs) def _retry_operation_in_pool(self, callee: collections.abc.Callable, *args, **kwargs): return self.session_pool.retry_operation_sync(callee, None, *args, **kwargs) def _rows_iterable(self, chunks_iterable: ydb.convert.ResultSets): try: for chunk in chunks_iterable: self.description = [ ( col.name, get_column_type(col.type), None, None, None, None, None, ) for col in chunk.columns ] for row in chunk.rows: # returns tuple to be compatible with SqlAlchemy and because # of this PEP to return a sequence: https://www.python.org/dev/peps/pep-0249/#fetchmany yield row[::] except ydb.Error as e: raise DatabaseError(e.message, original_error=e) from e def _ensure_prefetched(self): if self.rows is not None and self._rows_prefetched is None: self._rows_prefetched = list(self.rows) self.rows = iter(self._rows_prefetched) return self._rows_prefetched
[docs] def executemany(self, operation: YdbQuery, seq_of_parameters: Optional[Sequence[Mapping[str, Any]]]): for parameters in seq_of_parameters: self.execute(operation, parameters)
[docs] def executescript(self, script): return self.execute(script)
[docs] def fetchone(self): return next(self.rows or [], None)
[docs] def fetchmany(self, size=None): return list(itertools.islice(self.rows, size or self.arraysize))
[docs] def fetchall(self): return list(self.rows)
[docs] def nextset(self): self.fetchall()
[docs] def setinputsizes(self, sizes): pass
[docs] def setoutputsize(self, column=None): pass
[docs] def close(self): self.rows = None self._rows_prefetched = None
@property
[docs] def rowcount(self): return len(self._ensure_prefetched())
[docs]class AsyncCursor(Cursor): _await = staticmethod(util.await_only) @staticmethod async def _describe_table(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.TableDescription: return await session.describe_table(abs_table_path) @staticmethod async def _describe_path(session: ydb.aio.table.Session, abs_table_path: str) -> ydb.SchemeEntry: return await session._driver.scheme_client.describe_path(abs_table_path) @staticmethod async def _list_directory(session: ydb.aio.table.Session, abs_dir_path: str) -> ydb.Directory: return await session._driver.scheme_client.list_directory(abs_dir_path) @staticmethod async def _execute_scheme(session: ydb.aio.table.Session, query: str) -> ydb.convert.ResultSets: return await session.execute_scheme(query) @staticmethod async def _prepare(session: ydb.aio.table.Session, query: str) -> ydb.DataQuery: return await session.prepare(query) @staticmethod async def _execute_in_tx( tx_context: ydb.aio.table.TxContext, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]] ) -> ydb.convert.ResultSets: return await tx_context.execute(prepared_query, parameters, commit_tx=False) @staticmethod async def _execute_in_session( session: ydb.aio.table.Session, tx_mode: ydb.AbstractTransactionModeBuilder, prepared_query: ydb.DataQuery, parameters: Optional[Mapping[str, Any]], ) -> ydb.convert.ResultSets: return await session.transaction(tx_mode).execute(prepared_query, parameters, commit_tx=True) def _run_operation_in_tx(self, callee: collections.abc.Coroutine, *args, **kwargs): return self._await(callee(self.tx_context, *args, **kwargs)) def _run_operation_in_session(self, callee: collections.abc.Coroutine, *args, **kwargs): return self._await(callee(self.tx_context.session, *args, **kwargs)) def _retry_operation_in_pool(self, callee: collections.abc.Coroutine, *args, **kwargs): return self._await(self.session_pool.retry_operation(callee, *args, **kwargs))

Was this entry helpful?