| Viewing file:  _autogen_fixtures.py (9.65 KB)      -rw-r--r-- Select action/file-type:
 
  (+) |  (+) |  (+) | Code (+) | Session (+) |  (+) | SDB (+) |  (+) |  (+) |  (+) |  (+) |  (+) | 
 
from __future__ import annotations
 from typing import Any
 from typing import Dict
 from typing import Set
 
 from sqlalchemy import CHAR
 from sqlalchemy import CheckConstraint
 from sqlalchemy import Column
 from sqlalchemy import event
 from sqlalchemy import ForeignKey
 from sqlalchemy import Index
 from sqlalchemy import inspect
 from sqlalchemy import Integer
 from sqlalchemy import MetaData
 from sqlalchemy import Numeric
 from sqlalchemy import String
 from sqlalchemy import Table
 from sqlalchemy import Text
 from sqlalchemy import text
 from sqlalchemy import UniqueConstraint
 
 from ... import autogenerate
 from ... import util
 from ...autogenerate import api
 from ...ddl.base import _fk_spec
 from ...migration import MigrationContext
 from ...operations import ops
 from ...testing import config
 from ...testing import eq_
 from ...testing.env import clear_staging_env
 from ...testing.env import staging_env
 
 # Names of every Table attached to a parent after this module is imported;
 # consulted by _default_include_object to decide which tables to compare.
 names_in_this_test: Set[Any] = set()


 def new_table(table, parent):
     """Record ``table.name`` each time a Table is attached to its parent."""
     names_in_this_test.add(table.name)


 # Register the listener explicitly (equivalent to @event.listens_for).
 event.listen(Table, "after_parent_attach", new_table)
 
 
 def _default_include_object(obj, name, type_, reflected, compare_to):
 if type_ == "table":
 return name in names_in_this_test
 else:
 return True
 
 
 _default_object_filters: Any = _default_include_object
 
 _default_name_filters: Any = None
 
 
 class ModelOne:
     """Pair of schemas used by autogenerate comparison tests.

     ``_get_db_schema()`` describes what exists in the database and
     ``_get_model_schema()`` describes the target model.  The two differ in
     columns, types, nullability, server defaults, indexes, unique/check/
     foreign-key constraints, and in which tables exist at all.
     """

     # Requirement names for tests built on this fixture; presumably checked
     # by the testing plugin before running them -- confirm.
     __requires__ = ("unique_constraint_reflection",)

     # Schema passed to both MetaData objects; None leaves tables without an
     # explicit schema.
     schema: Any = None

     @classmethod
     def _get_db_schema(cls):
         """Return the MetaData for the schema already present in the database.

         ``AutogenTest.setup_class`` creates this metadata with ``create_all``.
         Each ``Table(...)`` below attaches to ``m`` and is therefore recorded
         in ``names_in_this_test`` by this module's listener.
         """
         schema = cls.schema

         m = MetaData(schema=schema)

         # Compared with the model: "pw" and "pw_idx" exist only here, and
         # "name"/"a1" lack the model's NOT NULL / server default.
         Table(
             "user",
             m,
             Column("id", Integer, primary_key=True),
             Column("name", String(50)),
             Column("a1", Text),
             Column("pw", String(50)),
             Index("pw_idx", "pw"),
         )

         # The model adds a "street" column and the "uq_email" constraint.
         Table(
             "address",
             m,
             Column("id", Integer, primary_key=True),
             Column("email_address", String(100), nullable=False),
         )

         # The model changes "amount" to Numeric(10, 2) and nullable, adds a
         # "user_id" foreign key, and rewrites the ck_order_amount text.
         Table(
             "order",
             m,
             Column("order_id", Integer, primary_key=True),
             Column(
                 "amount",
                 Numeric(8, 2),
                 nullable=False,
                 server_default=text("0"),
             ),
             CheckConstraint("amount >= 0", name="ck_order_amount"),
         )

         # Exists only in the database; absent from the model.
         Table(
             "extra",
             m,
             Column("x", CHAR),
             Column("uid", Integer, ForeignKey("user.id")),
         )

         return m

     @classmethod
     def _get_model_schema(cls):
         """Return the target MetaData that autogenerate compares against.

         ``AutogenTest`` uses this as ``target_metadata``; it is not created
         in the database by ``setup_class``.
         """
         schema = cls.schema

         m = MetaData(schema=schema)

         # "name" becomes NOT NULL, "a1" gains a server default, and the
         # "pw" column and "pw_idx" index are gone.
         Table(
             "user",
             m,
             Column("id", Integer, primary_key=True),
             Column("name", String(50), nullable=False),
             Column("a1", Text, server_default="x"),
         )

         # New "street" column and new named unique constraint.
         Table(
             "address",
             m,
             Column("id", Integer, primary_key=True),
             Column("email_address", String(100), nullable=False),
             Column("street", String(50)),
             UniqueConstraint("email_address", name="uq_email"),
         )

         # Wider, nullable "amount"; new FK column; same check constraint
         # name with different SQL text.
         Table(
             "order",
             m,
             Column("order_id", Integer, primary_key=True),
             Column(
                 "amount",
                 Numeric(10, 2),
                 nullable=True,
                 server_default=text("0"),
             ),
             Column("user_id", Integer, ForeignKey("user.id")),
             CheckConstraint("amount > -1", name="ck_order_amount"),
         )

         # Exists only in the model (new table with an FK and an unnamed
         # check constraint).
         Table(
             "item",
             m,
             Column("id", Integer, primary_key=True),
             Column("description", String(100)),
             Column("order_id", Integer, ForeignKey("order.order_id")),
             CheckConstraint("len(description) > 5"),
         )
         return m
 
 
 class _ComparesFKs:
     """Mixin with an assertion helper for foreign-key autogenerate diffs."""

     def _assert_fk_diff(
         self,
         diff,
         type_,
         source_table,
         source_columns,
         target_table,
         target_columns,
         name=None,
         conditional_name=None,
         source_schema=None,
         onupdate=None,
         ondelete=None,
         initially=None,
         deferrable=None,
     ):
         """Assert ``diff`` is a ``(type_, constraint)`` entry with these traits.

         The constraint's source schema/table/columns, target table, target
         column names, ON UPDATE / ON DELETE, INITIALLY and DEFERRABLE are
         each compared with ``eq_``.  The constraint name is compared against
         ``conditional_name`` when it is given -- the special value
         ``"servergenerated"`` means "the name of the first foreign key
         reflected on ``source_table`` through ``self.bind``" -- and against
         ``name`` otherwise.
         """
         fk = diff[1]
         # _fk_spec is a private helper; the public ForeignKeyConstraint API
         # was too thin in old SQLAlchemy releases to get these elements.
         (
             got_schema,
             got_table,
             got_columns,
             _,  # target schema: not asserted
             got_target_table,
             _,  # target columns: asserted via fk.elements below
             got_onupdate,
             got_ondelete,
             got_deferrable,
             got_initially,
         ) = _fk_spec(fk)

         eq_(diff[0], type_)
         eq_(got_table, source_table)
         eq_(got_columns, source_columns)
         eq_(got_target_table, target_table)
         eq_(got_schema, source_schema)
         eq_(got_onupdate, onupdate)
         eq_(got_ondelete, ondelete)
         eq_(got_initially, initially)
         eq_(got_deferrable, deferrable)

         referenced = [element.column.name for element in fk.elements]
         eq_(referenced, target_columns)

         if conditional_name is None:
             expected_name = name
         elif conditional_name == "servergenerated":
             reflected = inspect(self.bind).get_foreign_keys(source_table)
             expected_name = reflected[0]["name"]
         else:
             expected_name = conditional_name
         eq_(fk.name, expected_name)
 
 
 class AutogenTest(_ComparesFKs):
     """Class-scoped autogenerate fixture.

     Subclasses supply ``_get_db_schema()`` and ``_get_model_schema()``
     (``ModelOne`` defines both).  ``setup_class`` creates the first in the
     database and keeps the second as target metadata.  Each test then gets
     its own connection, ``MigrationContext`` and ``AutogenContext``
     comparing the two.
     """

     def _flatten_diffs(self, diffs):
         """Yield the non-list entries of ``diffs``, descending into lists."""
         for d in diffs:
             if isinstance(d, list):
                 yield from self._flatten_diffs(d)
             else:
                 yield d

     @classmethod
     def _get_bind(cls):
         """Return the bind shared by the whole class (``config.db``)."""
         return config.db

     # Extra MigrationContext options merged over the defaults in setUp();
     # only read, never mutated, so sharing the dict is safe.
     configure_opts: Dict[Any, Any] = {}

     @classmethod
     def setup_class(cls):
         """Enter the staging env and create the database-side schema."""
         staging_env()
         try:
             cls.bind = cls._get_bind()
             cls.m1 = cls._get_db_schema()
             cls.m1.create_all(cls.bind)
             cls.m2 = cls._get_model_schema()
         except BaseException:
             # The runner may not call teardown_class when setup_class fails,
             # so do not leave the staging environment behind.
             clear_staging_env()
             raise

     @classmethod
     def teardown_class(cls):
         """Drop the database-side schema and clear the staging env."""
         try:
             cls.m1.drop_all(cls.bind)
         finally:
             # Clear staging even if drop_all() raises.
             clear_staging_env()

     def setUp(self):
         """Open a connection and build this test's migration contexts."""
         self.conn = conn = self.bind.connect()
         ctx_opts = {
             "compare_type": True,
             "compare_server_default": True,
             "target_metadata": self.m2,
             "upgrade_token": "upgrades",
             "downgrade_token": "downgrades",
             "alembic_module_prefix": "op.",
             "sqlalchemy_module_prefix": "sa.",
             "include_object": _default_object_filters,
             "include_name": _default_name_filters,
         }
         if self.configure_opts:
             ctx_opts.update(self.configure_opts)
         try:
             self.context = context = MigrationContext.configure(
                 connection=conn, opts=ctx_opts
             )

             self.autogen_context = api.AutogenContext(context, self.m2)
         except BaseException:
             # tearDown() is not run when setUp() fails, so close the
             # connection here rather than leak it.
             conn.close()
             raise

     def tearDown(self):
         """Close the per-test connection opened in setUp()."""
         self.conn.close()

     def _update_context(
         self, object_filters=None, name_filters=None, include_schemas=None
     ):
         """Adjust this test's ``AutogenContext`` in place and return it.

         Only arguments that are not None are applied.  ``include_schemas``
         is written into ``opts``, and each filter replaces the corresponding
         filter list with a one-element list.
         """
         if include_schemas is not None:
             self.autogen_context.opts["include_schemas"] = include_schemas
         if object_filters is not None:
             self.autogen_context._object_filters = [object_filters]
         if name_filters is not None:
             self.autogen_context._name_filters = [name_filters]
         return self.autogen_context
 
 
 class AutogenFixtureTest(_ComparesFKs):
     """Per-test autogenerate fixture.

     Tests call ``_fixture(m1, m2)``: ``m1`` is created in the database,
     ``m2`` is the target model, and the autogenerate result is returned.
     ``tearDown`` drops whatever ``m1`` created.
     """

     def _fixture(
         self,
         m1,
         m2,
         include_schemas=False,
         opts=None,
         object_filters=_default_object_filters,
         name_filters=_default_name_filters,
         return_ops=False,
         max_identifier_length=None,
     ):
         """Create ``m1`` in the database and autogenerate against ``m2``.

         :param m1: MetaData, or a list of them, created with ``create_all``;
          stored on ``self`` so ``tearDown`` can drop it.
         :param m2: target metadata compared against the database.
         :param include_schemas: value of the ``include_schemas`` option.
         :param opts: extra ``MigrationContext`` options, applied last.
         :param object_filters: the ``include_object`` hook.
         :param name_filters: the ``include_name`` hook.
         :param return_ops: return the ``UpgradeOps`` instead of its diffs.
         :param max_identifier_length: if truthy, override the dialect's
          identifier length limit for the duration of this call.
         :return: ``UpgradeOps`` when ``return_ops`` else ``uo.as_diffs()``.
         """
         if max_identifier_length:
             dialect = self.bind.dialect
             saved_length = dialect.max_identifier_length
             # Save the user-defined limit separately: it must be restored to
             # its own prior value (which may be None), not to the effective
             # length, or the dialect is left permanently altered.
             saved_user_defined = getattr(
                 dialect, "_user_defined_max_identifier_length", None
             )
             dialect.max_identifier_length = (
                 dialect._user_defined_max_identifier_length
             ) = max_identifier_length
         try:
             self._alembic_metadata, model_metadata = m1, m2
             for m in util.to_list(self._alembic_metadata):
                 m.create_all(self.bind)

             with self.bind.connect() as conn:
                 ctx_opts = {
                     "compare_type": True,
                     "compare_server_default": True,
                     "target_metadata": model_metadata,
                     "upgrade_token": "upgrades",
                     "downgrade_token": "downgrades",
                     "alembic_module_prefix": "op.",
                     "sqlalchemy_module_prefix": "sa.",
                     "include_object": object_filters,
                     "include_name": name_filters,
                     "include_schemas": include_schemas,
                 }
                 if opts:
                     ctx_opts.update(opts)
                 self.context = context = MigrationContext.configure(
                     connection=conn, opts=ctx_opts
                 )

                 autogen_context = api.AutogenContext(context, model_metadata)
                 uo = ops.UpgradeOps(ops=[])
                 autogenerate._produce_net_changes(autogen_context, uo)

                 if return_ops:
                     return uo
                 else:
                     return uo.as_diffs()
         finally:
             if max_identifier_length:
                 dialect.max_identifier_length = saved_length
                 dialect._user_defined_max_identifier_length = (
                     saved_user_defined
                 )

     def setUp(self):
         """Enter the staging env and bind to the configured database."""
         staging_env()
         self.bind = config.db

     def tearDown(self):
         """Drop anything ``_fixture`` created, then clear the staging env."""
         try:
             if hasattr(self, "_alembic_metadata"):
                 for m in util.to_list(self._alembic_metadata):
                     m.drop_all(self.bind)
         finally:
             # Clear staging even if drop_all() raises.
             clear_staging_env()
 
 |