| Viewing file:  assertsql.py (13.27 KB)      -rw-r--r-- Select action/file-type:
 
  (+) |  (+) |  (+) | Code (+) | Session (+) |  (+) | SDB (+) |  (+) |  (+) |  (+) |  (+) |  (+) | 
 
# testing/assertsql.py
 # Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 
 import collections
 import contextlib
 import re
 
 from .. import event
 from .. import util
 from ..engine import url
 from ..engine.default import DefaultDialect
 from ..engine.util import _distill_params
 from ..schema import _DDLCompiles
 
 
 class AssertRule(object):
     """Base class for rules checked against observed SQL executions.

     The class attributes are defaults that subclasses and instances
     override while statements are being processed.
     """

     # set to True by a rule once it has been satisfied
     is_consumed = False
     # failure text set by a rule that rejected a statement, else None
     errormessage = None
     # when true, EachOf stops offering the current observed execution to
     # further rules once this rule has processed it
     consume_statement = True

     def process_statement(self, execute_observed):
         """Examine one observed execution; the base class does nothing."""
         pass

     def no_more_statements(self):
         """Fail because observed statements ran out with rules pending.

         :raises AssertionError: always.  Raised explicitly rather than
          through ``assert`` so the failure survives ``python -O``.
         """
         raise AssertionError(
             "All statements are complete, but pending "
             "assertion rules remain"
         )
 
 
 class SQLMatchRule(AssertRule):
     """Common base for rules that match one SQL statement.

     Adds no behavior of its own; CursorSQL and CompiledSQL derive from it.
     """
 
 
 class CursorSQL(SQLMatchRule):
     """Rule matching the exact statement string passed to the cursor.

     ``params`` is compared to the cursor parameters only when it is not
     None.
     """

     # a match removes just one cursor statement from the observed
     # execution; the instance flips this to True once none are left
     consume_statement = False

     def __init__(self, statement, params=None):
         self.statement = statement
         self.params = params

     def process_statement(self, execute_observed):
         """Compare this rule to the first pending cursor statement.

         On a match the cursor statement is popped from
         ``execute_observed.statements`` and the rule is marked consumed;
         otherwise ``errormessage`` is set.
         """
         observed = execute_observed.statements[0]

         mismatch = self.statement != observed.statement or (
             self.params is not None and self.params != observed.parameters
         )
         if not mismatch:
             execute_observed.statements.pop(0)
             self.is_consumed = True
             if not execute_observed.statements:
                 self.consume_statement = True
             return

         details = (
             self.statement,
             self.params,
             observed.statement,
             observed.parameters,
         )
         self.errormessage = (
             "Testing for exact SQL %s parameters %s received %s %s"
             % details
         )
 
 
 class CompiledSQL(SQLMatchRule):
     """Rule matching an execution by recompiling its statement to a string.

     The observed statement is compiled against the dialect selected by
     ``dialect`` and compared, with newlines and tabs removed, to
     ``statement``.  ``params``, when truthy, is a parameter mapping, a
     list of them, or a callable that receives the execution context and
     returns either form.
     """

     def __init__(self, statement, params=None, dialect="default"):
         self.statement = statement
         self.params = params
         self.dialect = dialect

     def _compare_sql(self, execute_observed, received_statement):
         """Return True if the received SQL equals the expected SQL once
         newlines and tabs are stripped from the expected text."""
         expected = re.sub(r"[\n\t]", "", self.statement)
         return received_statement == expected

     def _compile_dialect(self, execute_observed):
         """Build the dialect instance used to recompile the statement."""
         if self.dialect == "default":
             return DefaultDialect()

         # ugh
         dialect_kw = {}
         if self.dialect == "postgresql":
             dialect_kw["implicit_returning"] = True
         dialect_cls = url.URL(self.dialect).get_dialect()
         return dialect_cls(**dialect_kw)

     def _received_statement(self, execute_observed):
         """Reconstruct the statement and params in terms of a target
         dialect, which for CompiledSQL is just DefaultDialect.

         Returns a ``(statement_string, list_of_parameter_dicts)`` tuple.
         """
         context = execute_observed.context
         compare_dialect = self._compile_dialect(execute_observed)

         compile_kw = {"dialect": compare_dialect}
         if not isinstance(context.compiled.statement, _DDLCompiles):
             # DDL constructs are compiled without these two options
             compile_kw["column_keys"] = context.compiled.column_keys
             compile_kw["inline"] = context.compiled.inline
         compile_kw["schema_translate_map"] = context.execution_options.get(
             "schema_translate_map"
         )
         compiled = context.compiled.statement.compile(**compile_kw)

         received_sql = re.sub(r"[\n\t]", "", util.text_type(compiled))

         parameters = execute_observed.parameters
         if parameters:
             received_params = [
                 compiled.construct_params(m) for m in parameters
             ]
         else:
             received_params = [compiled.construct_params()]

         return received_sql, received_params

     def process_statement(self, execute_observed):
         """Match the observed execution against this rule.

         Sets ``is_consumed`` on success, ``errormessage`` on failure.
         """
         context = execute_observed.context

         received_sql, received_params = self._received_statement(
             execute_observed
         )
         expected_params = self._all_params(context)

         matched = self._compare_sql(execute_observed, received_sql)

         if matched and expected_params is not None:
             pending = list(expected_params)
             unclaimed = list(received_params)
             while pending and unclaimed:
                 wanted = dict(pending.pop(0))

                 # do a positive compare only: locate the first received
                 # parameter set holding every key/value of ``wanted``
                 position = next(
                     (
                         idx
                         for idx, candidate in enumerate(unclaimed)
                         if not any(
                             key not in candidate
                             or candidate[key] != wanted[key]
                             for key in wanted
                         )
                     ),
                     None,
                 )
                 if position is None:
                     # ``wanted`` did not match any received entry
                     matched = False
                     break
                 # matched; that entry cannot be claimed again
                 del unclaimed[position]

             if pending or unclaimed:
                 matched = False

         if matched:
             self.is_consumed = True
             self.errormessage = None
         else:
             self.errormessage = self._failure_message(expected_params) % {
                 "received_statement": received_sql,
                 "received_parameters": received_params,
             }

     def _all_params(self, context):
         """Return the expected parameters as a list, or None when
         ``self.params`` is falsy."""
         if not self.params:
             return None

         if util.callable(self.params):
             resolved = self.params(context)
         else:
             resolved = self.params
         return resolved if isinstance(resolved, list) else [resolved]

     def _failure_message(self, expected_params):
         """Return a %-template with ``received_statement`` and
         ``received_parameters`` placeholders left to be filled in."""
         # percent signs in the inserted values are doubled so they survive
         # the second round of %-formatting applied by the caller
         escaped_statement = self.statement.replace("%", "%%")
         escaped_params = repr(expected_params).replace("%", "%%")
         template = (
             "Testing for compiled statement %r partial params %s, "
             "received %%(received_statement)r with params "
             "%%(received_parameters)r"
         )
         return template % (escaped_statement, escaped_params)
 
 
 class RegexSQL(CompiledSQL):
     """CompiledSQL variant whose expected statement is a regular
     expression matched at the start of the received SQL."""

     def __init__(self, regex, params=None, dialect="default"):
         SQLMatchRule.__init__(self)
         self.orig_regex = regex
         self.regex = re.compile(regex)
         self.params = params
         self.dialect = dialect

     def _failure_message(self, expected_params):
         """Return a %-template with ``received_statement`` and
         ``received_parameters`` placeholders left to be filled in."""
         # percent signs are doubled so they survive the caller's second
         # round of %-formatting
         escaped_regex = self.orig_regex.replace("%", "%%")
         escaped_params = repr(expected_params).replace("%", "%%")
         template = (
             "Testing for compiled statement ~%r partial params %s, "
             "received %%(received_statement)r with params "
             "%%(received_parameters)r"
         )
         return template % (escaped_regex, escaped_params)

     def _compare_sql(self, execute_observed, received_statement):
         """Return True if the compiled pattern matches the received SQL."""
         return self.regex.match(received_statement) is not None
 
 
 class DialectSQL(CompiledSQL):
     """CompiledSQL variant that compiles with the executing dialect.

     The expected statement is written with ``:name`` placeholders, which
     are rewritten to the executing dialect's paramstyle before comparing.
     """

     def _compile_dialect(self, execute_observed):
         """Return the dialect found on the observed execution context."""
         return execute_observed.context.dialect

     def _compare_no_space(self, real_stmt, received_stmt):
         """Compare ``real_stmt``, minus newlines and tabs, with
         ``received_stmt``."""
         stmt = re.sub(r"[\n\t]", "", real_stmt)
         return received_stmt == stmt

     def _received_statement(self, execute_observed):
         """Return the recompiled statement and the context's compiled
         parameters.

         :raises AssertionError: if the recompiled statement is not among
          the cursor statements recorded for this execution.
         """
         received_stmt, _ = super(DialectSQL, self)._received_statement(
             execute_observed
         )

         # TODO: why do we need this part?
         for real_stmt in execute_observed.statements:
             if self._compare_no_space(real_stmt.statement, received_stmt):
                 break
         else:
             raise AssertionError(
                 "Can't locate compiled statement %r in list of "
                 "statements actually invoked" % received_stmt
             )

         return received_stmt, execute_observed.context.compiled_parameters

     def _compare_sql(self, execute_observed, received_statement):
         """Compare the received SQL with the expected statement after
         converting its ``:name`` placeholders to the dialect paramstyle."""
         stmt = re.sub(r"[\n\t]", "", self.statement)

         # convert our comparison statement to have the
         # paramstyle of the received
         paramstyle = execute_observed.context.dialect.paramstyle
         if paramstyle == "pyformat":
             repl = r"%(\1)s"
         elif paramstyle == "qmark":
             repl = "?"
         elif paramstyle == "format":
             repl = r"%s"
         elif paramstyle == "numeric":
             # positional placeholders numbered by order of appearance
             counter = [0]

             def repl(match):
                 counter[0] += 1
                 return ":%d" % counter[0]

         else:
             # "named" already uses the :name form; for it and any other
             # style leave the statement alone, since re.sub() rejects a
             # None replacement with TypeError
             repl = None

         if repl is not None:
             stmt = re.sub(r":([\w_]+)", repl, stmt)

         return received_statement == stmt
 
 
 class CountStatements(AssertRule):
     """Rule asserting how many executions were observed in total.

     :param count: the number of ``process_statement()`` calls expected
      before ``no_more_statements()`` is invoked.
     """

     def __init__(self, count):
         self.count = count
         self._statement_count = 0

     def process_statement(self, execute_observed):
         """Count one observed execution; its content is not examined."""
         self._statement_count += 1

     def no_more_statements(self):
         """Check the final tally against the desired count.

         :raises AssertionError: if the counts differ.  Raised explicitly
          rather than through ``assert`` so the check survives
          ``python -O``.
         """
         if self.count != self._statement_count:
             raise AssertionError(
                 "desired statement count %d does not match %d"
                 % (self.count, self._statement_count)
             )
 
 
 class AllOf(AssertRule):
     """Rule satisfied once every contained rule has been consumed, with
     the rules allowed to match in any order."""

     def __init__(self, *rules):
         self.rules = set(rules)

     def process_statement(self, execute_observed):
         """Offer the observed execution to each remaining rule in turn,
         stopping at the first that consumes it or is still in progress."""
         for candidate in tuple(self.rules):
             candidate.errormessage = None
             candidate.process_statement(execute_observed)

             if candidate.is_consumed:
                 self.rules.discard(candidate)
                 if not self.rules:
                     self.is_consumed = True
                 return

             if not candidate.errormessage:
                 # rule is not done yet
                 self.errormessage = None
                 return

         # every remaining rule rejected the statement; report one of them
         self.errormessage = list(self.rules)[0].errormessage
 
 
 class EachOf(AssertRule):
     """Rule satisfied once every contained rule has been consumed, in the
     order the rules were given."""

     def __init__(self, *rules):
         self.rules = list(rules)

     def process_statement(self, execute_observed):
         """Apply the leading rule(s) to one observed execution.

         A rule whose ``consume_statement`` is false leaves the same
         execution available to the next rule in line.
         """
         while self.rules:
             rule = self.rules[0]
             rule.process_statement(execute_observed)
             if rule.is_consumed:
                 self.rules.pop(0)
             elif rule.errormessage:
                 self.errormessage = rule.errormessage
                 # The rule rejected this execution.  Stop here: if its
                 # consume_statement is false the loop would otherwise
                 # re-run the same rule against the same execution forever.
                 break
             if rule.consume_statement:
                 break

         if not self.rules:
             self.is_consumed = True

     def no_more_statements(self):
         """Delegate to the first pending rule if it is unconsumed,
         otherwise fail via the base class when rules remain."""
         if self.rules and not self.rules[0].is_consumed:
             self.rules[0].no_more_statements()
         elif self.rules:
             super(EachOf, self).no_more_statements()
 
 
 class Or(AllOf):
     """Rule satisfied as soon as any one contained rule is consumed."""

     def process_statement(self, execute_observed):
         """Offer the observed execution to each rule until one consumes
         it; if none does, report the error of one of the rules."""
         for alternative in self.rules:
             alternative.process_statement(execute_observed)
             if alternative.is_consumed:
                 self.is_consumed = True
                 return

         self.errormessage = list(self.rules)[0].errormessage
 
 
 class SQLExecuteObserved(object):
     """Record of one statement execution together with the cursor-level
     statements collected for it in ``statements``."""

     def __init__(self, context, clauseelement, multiparams, params):
         # filled in afterwards with SQLCursorExecuteObserved entries
         self.statements = []
         self.parameters = _distill_params(multiparams, params)
         self.clauseelement = clauseelement
         self.context = context
 
 
 class SQLCursorExecuteObserved(
     collections.namedtuple(
         "SQLCursorExecuteObserved",
         "statement parameters context executemany",
     )
 ):
     """Immutable record of the arguments of one cursor execution."""
 
 
 class SQLAsserter(object):
     """Collects observed executions and checks them against rules.

     Observations are appended to ``accumulated``; ``_close()`` freezes
     them into ``_final``, which ``assert_()`` reads.
     """

     def __init__(self):
         self.accumulated = []

     def _close(self):
         """Move the collected observations to ``_final`` and remove the
         ``accumulated`` attribute."""
         self._final = self.accumulated
         del self.accumulated

     def assert_(self, *rules):
         """Check the finalized observations against ``rules`` in order.

         Must be called after ``_close()``, since it reads ``_final``.

         :raises AssertionError: if a rule rejects a statement, if
          statements remain after all rules are consumed, or if rules
          remain after all statements are processed.  Raised explicitly
          rather than through ``assert`` so the checks survive
          ``python -O``.
         """
         rule = EachOf(*rules)

         observed = list(self._final)
         while observed:
             statement = observed.pop(0)
             rule.process_statement(statement)
             if rule.is_consumed:
                 break
             elif rule.errormessage:
                 raise AssertionError(rule.errormessage)

         if observed:
             raise AssertionError("Additional SQL statements remain")
         elif not rule.is_consumed:
             rule.no_more_statements()
 
 
 @contextlib.contextmanager
 def assert_engine(engine):
     """Context manager that records executions on ``engine``.

     Yields a SQLAsserter that accumulates observations while the block
     runs.  On exit both event listeners are removed and the asserter is
     closed.
     """
     asserter = SQLAsserter()

     # most recent (clauseelement, multiparams, params) seen before execute
     orig = []

     @event.listens_for(engine, "before_execute")
     def connection_execute(conn, clauseelement, multiparams, params):
         # grab the original statement + params before any cursor
         # execution
         orig[:] = clauseelement, multiparams, params

     @event.listens_for(engine, "after_cursor_execute")
     def cursor_execute(
         conn, cursor, statement, parameters, context, executemany
     ):
         if not context:
             return

         # then grab real cursor statements and associate them all
         # around a single context
         history = asserter.accumulated
         same_context = history and history[-1].context is context
         if not same_context:
             history.append(
                 SQLExecuteObserved(context, orig[0], orig[1], orig[2])
             )

         cursor_record = SQLCursorExecuteObserved(
             statement, parameters, context, executemany
         )
         history[-1].statements.append(cursor_record)

     try:
         yield asserter
     finally:
         event.remove(engine, "after_cursor_execute", cursor_execute)
         event.remove(engine, "before_execute", connection_execute)
         asserter._close()
 
 |