753 lines
23 KiB
Python
Raw Normal View History

2021-01-02 23:12:27 -06:00
# testing/assertions.py
2021-03-27 16:21:31 -05:00
# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
2021-01-02 23:12:27 -06:00
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
from __future__ import absolute_import
import contextlib
import re
import sys
import warnings
from . import assertsql
from . import config
2021-03-27 16:21:31 -05:00
from . import engines
2021-01-02 23:12:27 -06:00
from . import mock
from .exclusions import db_spec
from .util import fail
from .. import exc as sa_exc
from .. import schema
from .. import sql
from .. import types as sqltypes
from .. import util
from ..engine import default
from ..engine import url
2021-03-27 16:21:31 -05:00
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL
2021-01-02 23:12:27 -06:00
from ..util import compat
from ..util import decorator
def expect_warnings(*messages, **kw):
    """Context manager which expects one or more warnings.

    The warning classes intercepted are SAWarning and RemovedIn20Warning.
    With no ``messages``, every warning of those classes is squelched.
    Otherwise each message is a string expression matched against the
    warning text via regex; warnings that match none of them are sent
    through to the real ``warnings.warn``.

    The expect version **asserts** that the warnings were in fact seen
    (pass ``assert_=False`` to disable that check).

    Note that the test suite sets SAWarning warnings to raise exceptions.

    Any keyword arguments are handed unchanged to ``_expect_warnings()``.

    """
    intercepted = (sa_exc.RemovedIn20Warning, sa_exc.SAWarning)
    return _expect_warnings(intercepted, messages, **kw)
2021-01-02 23:12:27 -06:00
@contextlib.contextmanager
def expect_warnings_on(db, *messages, **kw):
    """Context manager which expects one or more warnings on specific
    dialects.

    When ``db`` is a string and the spec built from it does not match the
    current configuration, the block runs with no interception at all.
    In every other case the block runs inside ``expect_warnings()``.

    The expect version **asserts** that the warnings were in fact seen.

    """
    matcher = db_spec(db)

    # only a string spec can opt out; the matcher is consulted solely
    # in that case
    not_applicable = isinstance(db, util.string_types) and not matcher(
        config._current
    )

    if not not_applicable:
        with expect_warnings(*messages, **kw):
            yield
    else:
        yield
def emits_warning(*messages):
    """Decorator form of expect_warnings().

    The decorated function runs inside ``expect_warnings(*messages)``
    with the "was it seen" assertion switched off, so emits_warning does
    **not** assert that the warnings were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        with expect_warnings(*messages, assert_=False):
            return fn(*args, **kw)

    return decorate
def expect_deprecated(*messages, **kw):
    """Context manager which expects SADeprecationWarning warnings.

    ``messages`` and keyword arguments are interpreted by
    ``_expect_warnings()``.

    """
    warning_cls = sa_exc.SADeprecationWarning
    return _expect_warnings(warning_cls, messages, **kw)
2021-03-27 16:21:31 -05:00
def expect_deprecated_20(*messages, **kw):
    """Context manager which expects RemovedIn20Warning warnings.

    ``messages`` and keyword arguments are interpreted by
    ``_expect_warnings()``.

    """
    warning_cls = sa_exc.RemovedIn20Warning
    return _expect_warnings(warning_cls, messages, **kw)
2021-01-02 23:12:27 -06:00
def emits_warning_on(db, *messages):
    """Mark a test as emitting a warning on a specific dialect.

    Decorator form of ``expect_warnings_on()``.  With no messages, all
    SAWarning / RemovedIn20Warning warnings are squelched; otherwise pass
    one or more strings, which are matched as regular expressions against
    the start of the warning text.

    Note that emits_warning_on does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        with expect_warnings_on(db, *messages, assert_=False):
            return fn(*args, **kw)

    return decorate
def uses_deprecated(*messages):
    """Mark a test as immune from fatal deprecation warnings.

    Decorator form of ``expect_deprecated()``.  With no arguments,
    squelches all SADeprecationWarning warnings.  Or pass one or more
    strings; these are matched as regular expressions against the start
    of the warning text.

    Note that uses_deprecated does **not** assert that the warnings
    were in fact seen.

    """

    @decorator
    def decorate(fn, *args, **kw):
        # assert_=False: the warnings are tolerated, not required
        with expect_deprecated(*messages, assert_=False):
            return fn(*args, **kw)

    return decorate
@contextlib.contextmanager
def _expect_warnings(
    exc_cls,
    messages,
    regex=True,
    assert_=True,
    py2konly=False,
    raise_on_any_unexpected=False,
):
    """Intercept ``warnings.warn`` for the duration of the block.

    :param exc_cls: warning class, or tuple of classes, to intercept.
    :param messages: expected warning texts.  When empty (and
     ``raise_on_any_unexpected`` is false) every warning of ``exc_cls``
     is silently dropped.
    :param regex: when true, each message is compiled as a
     case-insensitive, dot-matches-all regular expression and applied
     with ``match()``; when false, messages are compared for equality.
    :param assert_: when true, assert on exit that every expected message
     was matched at least once.
    :param py2konly: when true, the exit assertion is skipped under
     Python 3.
    :param raise_on_any_unexpected: when true, a warning that would
     otherwise be passed through raises AssertionError instead.

    While the block runs, the two ``SQLALCHEMY_WARN_20`` flags are patched
    to True and ``LegacyRow._default_key_style`` is patched to 2.

    """
    if regex:
        filters = [re.compile(text, re.I | re.S) for text in messages]
    else:
        filters = messages

    # every expected filter starts out unseen; a match removes it
    unseen = set(filters)

    if raise_on_any_unexpected:

        def passthrough(msg, *arg, **kw):
            raise AssertionError("Got unexpected warning: %r" % msg)

    else:
        # captured before the patch below, so this is the genuine function
        passthrough = warnings.warn

    def is_match(filter_, text):
        if regex:
            return filter_.match(text)
        return filter_ == text

    def intercept(msg, *arg, **kw):
        # work out the warning category: from the instance itself, or
        # from the first positional argument after the message
        if isinstance(msg, exc_cls):
            category = type(msg)
            # NOTE(review): from here on the warning travels as a plain
            # string, so if it is passed through below its original
            # class is not forwarded unless it was also given in ``arg``.
            msg = str(msg)
        elif arg:
            category = arg[0]
        else:
            category = None

        if not category or not issubclass(category, exc_cls):
            return passthrough(msg, *arg, **kw)

        if not filters and not raise_on_any_unexpected:
            # nothing specific expected: squelch everything of exc_cls
            return None

        for filter_ in filters:
            if is_match(filter_, msg):
                unseen.discard(filter_)
                return None

        passthrough(msg, *arg, **kw)
        return None

    warn_patch = mock.patch("warnings.warn", intercept)
    util_flag_patch = mock.patch("sqlalchemy.util.SQLALCHEMY_WARN_20", True)
    deprecations_flag_patch = mock.patch(
        "sqlalchemy.util.deprecations.SQLALCHEMY_WARN_20", True
    )
    key_style_patch = mock.patch(
        "sqlalchemy.engine.row.LegacyRow._default_key_style", 2
    )

    with warn_patch, util_flag_patch, deprecations_flag_patch, key_style_patch:
        yield

    if not assert_:
        return
    if py2konly and compat.py3k:
        return

    assert not unseen, "Warnings were not seen: %s" % ", ".join(
        "%r" % (s.pattern if regex else s) for s in unseen
    )
def global_cleanup_assertions():
    """Check things that have to be finalized at the end of a test suite.

    The list of checks is hardcoded for now and consists only of the
    stray pool connection check; a modular system could be built here
    to support things like PG prepared transactions, tables all
    dropped, etc.

    """
    _assert_no_stray_pool_connections()
def _assert_no_stray_pool_connections():
    """Delegate to ``engines.testing_reaper.assert_all_closed()``.

    NOTE(review): presumably this fails when the testing reaper still
    tracks open connections -- confirm against ``engines``.

    """
    reaper = engines.testing_reaper
    reaper.assert_all_closed()
2021-01-02 23:12:27 -06:00
def eq_regex(a, b, msg=None):
    """Assert that regex ``b`` matches at the start of string ``a``.

    ``msg``, when given, replaces the default failure message.

    """
    assert re.match(b, a), (
        msg or "%r !~ %r" % (a, b)
    )
def eq_(a, b, msg=None):
    """Assert a == b, with repr messaging on failure.

    ``msg``, when given, replaces the default failure message.

    """
    assert a == b, (
        msg or "%r != %r" % (a, b)
    )
def ne_(a, b, msg=None):
    """Assert a != b, with repr messaging on failure.

    ``msg``, when given, replaces the default failure message.

    """
    assert a != b, (
        msg or "%r == %r" % (a, b)
    )
def le_(a, b, msg=None):
    """Assert a <= b, with repr messaging on failure.

    ``msg``, when given, replaces the default failure message.  The
    default message names the comparison that actually failed (``<=``)
    rather than ``!=``.

    """
    assert a <= b, msg or "%r is not <= %r" % (a, b)
def is_instance_of(a, b, msg=None):
    """Assert isinstance(a, b), with repr messaging on failure.

    ``msg``, when given, replaces the default failure message.

    """
    assert isinstance(a, b), (
        msg or "%r is not an instance of %r" % (a, b)
    )
2021-03-27 16:21:31 -05:00
def is_none(a, msg=None):
    """Assert a is None; failure messaging is that of ``is_()``."""
    expected = None
    is_(a, expected, msg=msg)
def is_not_none(a, msg=None):
    """Assert a is not None; failure messaging is that of ``is_not()``."""
    unexpected = None
    is_not(a, unexpected, msg=msg)
2021-01-02 23:12:27 -06:00
def is_true(a, msg=None):
    """Assert that ``a`` is truthy.

    The value is coerced with ``bool()`` first, so any truthy object
    passes, not only the ``True`` singleton.

    """
    truth = bool(a)
    is_(truth, True, msg=msg)
2021-01-02 23:12:27 -06:00
def is_false(a, msg=None):
    """Assert that ``a`` is falsy.

    The value is coerced with ``bool()`` first, so any falsy object
    passes, not only the ``False`` singleton.

    """
    truth = bool(a)
    is_(truth, False, msg=msg)
2021-01-02 23:12:27 -06:00
def is_(a, b, msg=None):
    """Assert a is b, with repr messaging on failure.

    ``msg``, when given, replaces the default failure message.

    """
    assert a is b, (
        msg or "%r is not %r" % (a, b)
    )
def is_not(a, b, msg=None):
    """Assert a is not b, with repr messaging on failure.

    ``msg``, when given, replaces the default failure message.

    """
    assert a is not b, (
        msg or "%r is %r" % (a, b)
    )


# deprecated alias kept for existing callers.  See #5429
is_not_ = is_not
def in_(a, b, msg=None):
    """Assert a in b, with repr messaging on failure.

    ``msg``, when given, replaces the default failure message.

    """
    assert a in b, (
        msg or "%r not in %r" % (a, b)
    )
def not_in(a, b, msg=None):
    """Assert a not in b, with repr messaging on failure.

    ``msg``, when given, replaces the default failure message.

    """
    assert a not in b, (
        msg or "%r is in %r" % (a, b)
    )


# deprecated alias kept for existing callers.  See #5429
not_in_ = not_in
def startswith_(a, fragment, msg=None):
    """Assert a.startswith(fragment), with repr messaging on failure.

    ``msg``, when given, replaces the default failure message.

    """
    assert a.startswith(fragment), (
        msg or "%r does not start with %r" % (a, fragment)
    )
def eq_ignore_whitespace(a, b, msg=None):
    """Assert a == b after normalizing whitespace in both strings.

    Normalization removes every newline and a single leading whitespace
    character, then collapses each run of two or more spaces into one
    space.  ``msg``, when given, replaces the default failure message.

    """

    def _normalize(text):
        text = re.sub(r"^\s+?|\n", "", text)
        return re.sub(r" {2,}", " ", text)

    a = _normalize(a)
    b = _normalize(b)
    assert a == b, msg or "%r != %r" % (a, b)
def _assert_proper_exception_context(exception):
    """Assert that a caught exception does not carry an implicit context.

    The check passes when the exception's ``__context__`` is the same
    object as its ``__cause__`` (including both being None), or when
    ``__suppress_context__`` is set.  Otherwise it fails.

    Python 3 will report nested exceptions as "during the handling of
    error X, error Y occurred".  That's not what we want to do; we want
    these exceptions in a cause chain.

    Nothing is checked when not running under Python 3.

    """
    if util.py3k:
        implicit_context = (
            exception.__context__ is not exception.__cause__
            and not exception.__suppress_context__
        )
        if implicit_context:
            assert False, (
                "Exception %r was correctly raised but did not set a cause, "
                "within context %r as its cause."
                % (exception, exception.__context__)
            )
def assert_raises(except_cls, callable_, *args, **kw):
    """Assert that ``callable_(*args, **kw)`` raises ``except_cls``.

    The exception context check is enabled.  Returns the caught
    exception.

    """
    return _assert_raises(
        except_cls, callable_, args, kw, check_context=True
    )
2021-01-02 23:12:27 -06:00
def assert_raises_context_ok(except_cls, callable_, *args, **kw):
    """Assert that ``callable_(*args, **kw)`` raises ``except_cls``.

    Same as ``assert_raises()`` but without the exception context check.
    Returns the caught exception.

    """
    positional, keyword = args, kw
    return _assert_raises(except_cls, callable_, positional, keyword)
2021-01-02 23:12:27 -06:00
def assert_raises_message(except_cls, msg, callable_, *args, **kwargs):
    """Assert that ``callable_(*args, **kwargs)`` raises ``except_cls``
    with text matching the regex ``msg``.

    The exception context check is enabled.  Returns the caught
    exception.

    """
    return _assert_raises(
        except_cls,
        callable_,
        args,
        kwargs,
        msg=msg,
        check_context=True,
    )
def assert_raises_message_context_ok(
    except_cls, msg, callable_, *args, **kwargs
):
    """Assert that ``callable_(*args, **kwargs)`` raises ``except_cls``
    with text matching the regex ``msg``.

    Same as ``assert_raises_message()`` but without the exception context
    check.  Returns the caught exception.

    """
    positional, keyword = args, kwargs
    return _assert_raises(
        except_cls, callable_, positional, keyword, msg=msg
    )
2021-01-02 23:12:27 -06:00
def _assert_raises(
    except_cls, callable_, args, kwargs, msg=None, check_context=False
):
    """Invoke ``callable_`` inside ``_expect_raises()``.

    :param args: positional arguments for ``callable_``, as a sequence.
    :param kwargs: keyword arguments for ``callable_``, as a mapping.
    :param msg: optional regex the exception text must match.
    :param check_context: whether to run the exception context check.
    :return: the exception that was caught.

    """
    with _expect_raises(
        except_cls, msg=msg, check_context=check_context
    ) as container:
        callable_(*args, **kwargs)
    return container.error
class _ErrorContainer(object):
error = None
@contextlib.contextmanager
def _expect_raises(except_cls, msg=None, check_context=False):
    """Context manager asserting that the block raises ``except_cls``.

    Yields an ``_ErrorContainer`` whose ``error`` attribute is set to the
    caught exception.

    :param msg: optional regex; when given, it must be found (via
     ``re.search``) in the text of the exception.
    :param check_context: when true, and no exception was already being
     handled when the block was entered, the caught exception is passed
     to ``_assert_proper_exception_context()``.

    The text of the caught exception is printed, UTF-8 encoded.

    """
    container = _ErrorContainer()

    if check_context:
        # type of the exception currently being handled, if any
        already_handling = sys.exc_info()[0]

    try:
        yield container
    except except_cls as err:
        container.error = err
        raised = True

        if msg is not None:
            found = re.search(msg, util.text_type(err), re.UNICODE)
            assert found, "%r !~ %s" % (msg, err)

        if check_context and not already_handling:
            _assert_proper_exception_context(err)

        print(util.text_type(err).encode("utf-8"))
    else:
        raised = False

    # assert outside the block so it works for AssertionError too !
    assert raised, "Callable did not raise an exception"
2021-03-27 16:21:31 -05:00
def expect_raises(except_cls, check_context=True):
    """Context manager asserting that the block raises ``except_cls``.

    The exception context check is on by default.  The yielded object
    exposes the caught exception as ``.error``.

    """
    return _expect_raises(
        except_cls,
        check_context=check_context,
    )
def expect_raises_message(except_cls, msg, check_context=True):
    """Context manager asserting that the block raises ``except_cls``
    with text matching the regex ``msg``.

    The exception context check is on by default.  The yielded object
    exposes the caught exception as ``.error``.

    """
    return _expect_raises(
        except_cls,
        msg=msg,
        check_context=check_context,
    )
2021-01-02 23:12:27 -06:00
class AssertsCompiledSQL(object):
    """Mixin providing ``assert_compile()`` for SQL string assertions."""

    def assert_compile(
        self,
        clause,
        result,
        params=None,
        checkparams=None,
        for_executemany=False,
        check_literal_execute=None,
        check_post_param=None,
        dialect=None,
        checkpositional=None,
        check_prefetch=None,
        use_default_dialect=False,
        allow_dialect_select=False,
        supports_default_values=True,
        literal_binds=False,
        render_postcompile=False,
        schema_translate_map=None,
        render_schema_translate=False,
        default_schema_name=None,
        from_linting=False,
    ):
        """Compile ``clause`` and assert that the SQL string equals
        ``result`` once newlines and tabs are removed.

        Dialect selection, in order of precedence:

        * ``use_default_dialect`` - a fresh ``DefaultDialect``
        * ``allow_dialect_select`` - ``None`` is passed as the dialect
        * the ``dialect`` argument, falling back to ``self.__dialect__``,
          falling back to ``config.db.dialect``; the strings ``"default"``
          and ``"default_enhanced"`` select ``DefaultDialect`` and
          ``StrCompileDialect``, any other string is treated as a URL
          from which the dialect is created

        Optional checks, each applied only when its argument is not None:

        * ``checkparams`` - compared to ``construct_params(params)``
        * ``checkpositional`` - compared to the constructed parameters
          ordered by ``positiontup``
        * ``check_prefetch`` - compared to the compiled ``prefetch``
        * ``check_literal_execute`` / ``check_post_param`` - compared to
          a name -> effective value dict built from the compiled
          ``literal_execute_params`` / ``post_compile_params``

        The remaining flags are forwarded to ``compile()`` when set.  An
        ``orm.Query`` is first converted to its statement with the
        tablename-plus-column label style.  The SQL string and parameters
        are printed as a side effect.

        """
        if use_default_dialect:
            dialect = default.DefaultDialect()
            dialect.supports_default_values = supports_default_values
        elif allow_dialect_select:
            dialect = None
        else:
            if dialect is None:
                dialect = getattr(self, "__dialect__", None)

            if dialect is None:
                dialect = config.db.dialect
            elif dialect == "default":
                dialect = default.DefaultDialect()
                dialect.supports_default_values = supports_default_values
            elif dialect == "default_enhanced":
                dialect = default.StrCompileDialect()
            elif isinstance(dialect, util.string_types):
                dialect_cls = url.URL.create(dialect).get_dialect()
                dialect = dialect_cls()

        if default_schema_name:
            dialect.default_schema_name = default_schema_name

        # options nested under the "compile_kwargs" key
        compile_kwargs = {}
        if literal_binds:
            compile_kwargs["literal_binds"] = True
        if render_postcompile:
            compile_kwargs["render_postcompile"] = True

        # options passed directly to compile()
        kw = {}
        if schema_translate_map:
            kw["schema_translate_map"] = schema_translate_map
        if params is not None:
            kw["column_keys"] = list(params)
        if for_executemany:
            kw["for_executemany"] = True
        if render_schema_translate:
            kw["render_schema_translate"] = True
        if from_linting or getattr(self, "assert_from_linting", False):
            kw["linting"] = sql.FROM_LINTING

        from sqlalchemy import orm

        if isinstance(clause, orm.Query):
            stmt = clause._statement_20()
            stmt._label_style = LABEL_STYLE_TABLENAME_PLUS_COL
            clause = stmt

        if compile_kwargs:
            kw["compile_kwargs"] = compile_kwargs

        class DontAccess(object):
            # stand-in for compiler.statement; any attribute access fails
            def __getattribute__(self, key):
                raise NotImplementedError(
                    "compiler accessed .statement; use "
                    "compiler.current_executable"
                )

        class CheckCompilerAccess(object):
            # proxy that is compiled in place of the real statement, so
            # that the compiler's "statement" is not the construct under
            # test
            def __init__(self, test_statement):
                self.test_statement = test_statement
                self._annotations = {}
                self.supports_execution = getattr(
                    test_statement, "supports_execution", False
                )

                if self.supports_execution:
                    self._execution_options = test_statement._execution_options

                    # mirror the DML attributes of the wrapped statement
                    if isinstance(
                        test_statement, (sql.Insert, sql.Update, sql.Delete)
                    ):
                        self._returning = test_statement._returning
                    if isinstance(test_statement, (sql.Insert, sql.Update)):
                        self._inline = test_statement._inline
                        self._return_defaults = test_statement._return_defaults

            def _default_dialect(self):
                return self.test_statement._default_dialect()

            def compile(self, dialect, **kw):
                # run the wrapped statement's compile() with the proxy
                # as "self"
                unbound_compile = self.test_statement.compile.__func__
                return unbound_compile(self, dialect=dialect, **kw)

            def _compiler(self, dialect, **kw):
                unbound_compiler = self.test_statement._compiler.__func__
                return unbound_compiler(self, dialect, **kw)

            def _compiler_dispatch(self, compiler, **kwargs):
                if not hasattr(compiler, "statement"):
                    return self.test_statement._compiler_dispatch(
                        compiler, **kwargs
                    )

                # hide compiler.statement for the duration of the dispatch
                with mock.patch.object(compiler, "statement", DontAccess()):
                    return self.test_statement._compiler_dispatch(
                        compiler, **kwargs
                    )

        # no construct can assume it's the "top level" construct in all cases
        # as anything can be nested.  ensure constructs don't assume they
        # are the "self.statement" element
        compiled = CheckCompilerAccess(clause).compile(dialect=dialect, **kw)

        param_str = repr(getattr(compiled, "params", {}))
        if util.py3k:
            param_str = param_str.encode("utf-8").decode("ascii", "ignore")
            print(
                (
                    "\nSQL String:\n" + util.text_type(compiled) + param_str
                ).encode("utf-8")
            )
        else:
            print(
                "\nSQL String:\n"
                + util.text_type(compiled).encode("utf-8")
                + param_str
            )

        flattened = re.sub(r"[\n\t]", "", util.text_type(compiled))

        eq_(
            flattened,
            result,
            "%r != %r on dialect %r" % (flattened, result, dialect),
        )

        if checkparams is not None:
            eq_(compiled.construct_params(params), checkparams)

        if checkpositional is not None:
            constructed = compiled.construct_params(params)
            eq_(
                tuple(constructed[name] for name in compiled.positiontup),
                checkpositional,
            )

        if check_prefetch is not None:
            eq_(compiled.prefetch, check_prefetch)

        if check_literal_execute is not None:
            literal_execute = {
                compiled.bind_names[bind]: bind.effective_value
                for bind in compiled.literal_execute_params
            }
            eq_(literal_execute, check_literal_execute)

        if check_post_param is not None:
            post_compile = {
                compiled.bind_names[bind]: bind.effective_value
                for bind in compiled.post_compile_params
            }
            eq_(post_compile, check_post_param)
2021-01-02 23:12:27 -06:00
class ComparesTables(object):
    """Mixin with assertions comparing a table to its reflected copy."""

    def assert_tables_equal(self, table, reflected_table, strict_types=False):
        """Assert that ``reflected_table`` corresponds to ``table``.

        Columns are compared pairwise in order: name, primary key flag,
        nullability, type, string length (for String types), and the
        names of the columns referred to by foreign keys.  When the
        original column has a server default, the reflected one must have
        a ``FetchedValue`` server default.  Finally the primary keys must
        have the same length, and each primary key column of ``table``
        must be present by name in the reflected primary key.

        :param strict_types: when true, the reflected type must be an
         instance of the original column's type class; otherwise the
         types are compared via ``assert_types_base()``.

        """
        assert len(table.c) == len(reflected_table.c)

        for col, reflected_col in zip(table.c, reflected_table.c):
            eq_(col.name, reflected_col.name)
            assert reflected_col is reflected_table.c[col.name]
            eq_(col.primary_key, reflected_col.primary_key)
            eq_(col.nullable, reflected_col.nullable)

            if not strict_types:
                self.assert_types_base(reflected_col, col)
            else:
                msg = "Type '%s' doesn't correspond to type '%s'"
                assert isinstance(reflected_col.type, type(col.type)), msg % (
                    reflected_col.type,
                    col.type,
                )

            if isinstance(col.type, sqltypes.String):
                eq_(col.type.length, reflected_col.type.length)

            original_fk_targets = {fk.column.name for fk in col.foreign_keys}
            reflected_fk_targets = {
                fk.column.name for fk in reflected_col.foreign_keys
            }
            eq_(original_fk_targets, reflected_fk_targets)

            if col.server_default:
                assert isinstance(
                    reflected_col.server_default, schema.FetchedValue
                )

        assert len(table.primary_key) == len(reflected_table.primary_key)
        for pk_col in table.primary_key:
            assert reflected_table.primary_key.columns[pk_col.name] is not None

    def assert_types_base(self, c1, c2):
        """Assert that the type of ``c1`` has the same type affinity as
        the type of ``c2``, per ``_compare_type_affinity()``.

        """
        assert c1.type._compare_type_affinity(c2.type), (
            "On column %r, type '%s' doesn't correspond to type '%s'"
            % (c1.name, c1.type, c2.type)
        )
class AssertsExecutionResults(object):
    """Mixin of assertions about result objects and about emitted SQL.

    Several methods call ``self.assert_()``, which is not defined here
    and has to be supplied by the class this mixin is combined with.

    """

    def assert_result(self, result, class_, *objects):
        """Materialize ``result``, print it, and check it with
        ``assert_list()`` against ``objects``.

        """
        rows = list(result)
        print(repr(rows))
        self.assert_list(rows, class_, objects)

    def assert_list(self, result, class_, list_):
        """Assert that ``result`` and ``list_`` have the same length and
        that each result item satisfies ``assert_row()`` for the
        description at the same position.

        """
        self.assert_(
            len(result) == len(list_),
            "result list is not the same size as test list, "
            + "for class "
            + class_.__name__,
        )
        for position, description in enumerate(list_):
            self.assert_row(class_, result[position], description)

    def assert_row(self, class_, rowobj, desc):
        """Assert that ``rowobj`` is exactly of ``class_`` and carries the
        attribute values described by the dict ``desc``.

        A tuple value ``(cls, spec)`` describes a related object when
        ``spec`` is a dict-like description (checked recursively with
        ``assert_row()``), or a collection when ``spec`` is a list
        (checked with ``assert_list()``).  Any other value is compared to
        the attribute with ``==``.

        """
        self.assert_(
            rowobj.__class__ is class_, "item class is not " + repr(class_)
        )
        for key, value in desc.items():
            if not isinstance(value, tuple):
                self.assert_(
                    getattr(rowobj, key) == value,
                    "attribute %s value %s does not match %s"
                    % (key, getattr(rowobj, key), value),
                )
                continue

            if isinstance(value[1], list):
                self.assert_list(getattr(rowobj, key), value[0], value[1])
            else:
                self.assert_row(value[0], getattr(rowobj, key), value[1])

    def assert_unordered_result(self, result, cls, *expected):
        """As assert_result, but the order of objects is not considered.

        Every expected description has to be matched by a distinct object
        from ``result``.  Returns True; failures are reported through
        ``fail()``.

        The algorithm is very expensive but not a big deal for the small
        numbers of rows that the test suite manipulates.

        """

        class immutabledict(dict):
            # hash by identity so plain dict descriptions can be
            # collected in a set
            def __hash__(self):
                return id(self)

        found = util.IdentitySet(result)
        expected = {immutabledict(e) for e in expected}

        for candidate in found:
            if not isinstance(candidate, cls):
                fail(
                    'Unexpected type "%s", expected "%s"'
                    % (type(candidate).__name__, cls.__name__)
                )

        if len(found) != len(expected):
            fail(
                'Unexpected object count "%s", expected "%s"'
                % (len(found), len(expected))
            )

        # sentinel for "attribute missing" / "no match found"
        NOVALUE = object()

        def _compare_item(obj, spec):
            for key, value in spec.items():
                if isinstance(value, tuple):
                    # nested collection: recurse, turning a failed
                    # assertion into "no match"
                    try:
                        self.assert_unordered_result(
                            getattr(obj, key), value[0], *value[1]
                        )
                    except AssertionError:
                        return False
                elif getattr(obj, key, NOVALUE) != value:
                    return False
            return True

        for expected_item in expected:
            match = next(
                (
                    found_item
                    for found_item in found
                    if _compare_item(found_item, expected_item)
                ),
                NOVALUE,
            )
            if match is NOVALUE:
                fail(
                    "Expected %s instance with attributes %s not found."
                    % (cls.__name__, repr(expected_item))
                )
            else:
                # each found object may satisfy one description only
                found.remove(match)
        return True

    def sql_execution_asserter(self, db=None):
        """Return ``assertsql.assert_engine()`` for ``db``, which defaults
        to the ``db`` attribute of the enclosing package.

        """
        if db is None:
            from . import db as db

        return assertsql.assert_engine(db)

    def assert_sql_execution(self, db, callable_, *rules):
        """Run ``callable_`` under a SQL execution asserter, apply
        ``rules`` to the asserter afterwards, and return the callable's
        result.

        """
        with self.sql_execution_asserter(db) as asserter:
            outcome = callable_()
        asserter.assert_(*rules)
        return outcome

    def assert_sql(self, db, callable_, rules):
        """Like ``assert_sql_execution()``, building the rules from plain
        data.

        A dict in ``rules`` becomes an ``AllOf`` of ``CompiledSQL`` rules,
        one per key/value pair; any other entry is unpacked as the
        arguments of a single ``CompiledSQL``.

        """

        def _as_rule(rule):
            if isinstance(rule, dict):
                return assertsql.AllOf(
                    *[assertsql.CompiledSQL(k, v) for k, v in rule.items()]
                )
            return assertsql.CompiledSQL(*rule)

        newrules = [_as_rule(rule) for rule in rules]
        return self.assert_sql_execution(db, callable_, *newrules)

    def assert_sql_count(self, db, callable_, count):
        """Run ``callable_`` and apply a ``CountStatements(count)`` rule
        for ``db``.

        """
        count_rule = assertsql.CountStatements(count)
        self.assert_sql_execution(db, callable_, count_rule)

    def assert_multiple_sql_count(self, dbs, callable_, counts):
        """Run ``callable_`` once while asserting on several databases.

        ``dbs`` and ``counts`` are paired up positionally; a
        ``CountStatements`` rule is applied per pair after the callable
        finishes.  Returns the callable's result.

        """
        recs = [
            (self.sql_execution_asserter(db), db, count)
            for db, count in zip(dbs, counts)
        ]
        asserters = [ctx.__enter__() for ctx, _db, _count in recs]
        try:
            return callable_()
        finally:
            # exit each context before applying its rule
            for asserter, (ctx, _db, count) in zip(asserters, recs):
                ctx.__exit__(None, None, None)
                asserter.assert_(assertsql.CountStatements(count))

    @contextlib.contextmanager
    def assert_execution(self, db, *rules):
        """Context manager form of ``assert_sql_execution()``: ``rules``
        are applied to the asserter once the block has finished.

        """
        with self.sql_execution_asserter(db) as asserter:
            yield
        asserter.assert_(*rules)

    def assert_statement_count(self, db, count):
        """Context manager applying a ``CountStatements(count)`` rule for
        ``db`` to the enclosed block.

        """
        count_rule = assertsql.CountStatements(count)
        return self.assert_execution(db, count_rule)