| Viewing file:  clsregistry.py (12.27 KB)      -rw-r--r-- Select action/file-type:
 
  (+) |  (+) |  (+) | Code (+) | Session (+) |  (+) | SDB (+) |  (+) |  (+) |  (+) |  (+) |  (+) | 
 
# ext/declarative/clsregistry.py# Copyright (C) 2005-2021 the SQLAlchemy authors and contributors
 # <see AUTHORS file>
 #
 # This module is part of SQLAlchemy and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
 """Routines to handle the string class registry used by declarative.
 
 This system allows specification of classes and expressions used in
 :func:`_orm.relationship` using strings.
 
 """
 import weakref
 
 from ... import exc
 from ... import inspection
 from ... import util
 from ...orm import class_mapper
 from ...orm import interfaces
 from ...orm.properties import ColumnProperty
 from ...orm.properties import RelationshipProperty
 from ...orm.properties import SynonymProperty
 from ...schema import _get_table_key
 
 
 # strong references to registries which we place in
 # the _decl_class_registry, which is usually weak referencing.
 # the internal registries here link to classes with weakrefs and remove
 # themselves when all references to contained classes are removed.
 _registries = set()
 
 
 def add_class(classname, cls):
 """Add a class to the _decl_class_registry associated with the
 given declarative class.
 
 """
 if classname in cls._decl_class_registry:
 # class already exists.
 existing = cls._decl_class_registry[classname]
 if not isinstance(existing, _MultipleClassMarker):
 existing = cls._decl_class_registry[
 classname
 ] = _MultipleClassMarker([cls, existing])
 else:
 cls._decl_class_registry[classname] = cls
 
 try:
 root_module = cls._decl_class_registry["_sa_module_registry"]
 except KeyError:
 cls._decl_class_registry[
 "_sa_module_registry"
 ] = root_module = _ModuleMarker("_sa_module_registry", None)
 
 tokens = cls.__module__.split(".")
 
 # build up a tree like this:
 # modulename:  myapp.snacks.nuts
 #
 # myapp->snack->nuts->(classes)
 # snack->nuts->(classes)
 # nuts->(classes)
 #
 # this allows partial token paths to be used.
 while tokens:
 token = tokens.pop(0)
 module = root_module.get_module(token)
 for token in tokens:
 module = module.get_module(token)
 module.add_class(classname, cls)
 
 
 class _MultipleClassMarker(object):
 """refers to multiple classes of the same name
 within _decl_class_registry.
 
 """
 
 __slots__ = "on_remove", "contents", "__weakref__"
 
 def __init__(self, classes, on_remove=None):
 self.on_remove = on_remove
 self.contents = set(
 [weakref.ref(item, self._remove_item) for item in classes]
 )
 _registries.add(self)
 
 def __iter__(self):
 return (ref() for ref in self.contents)
 
 def attempt_get(self, path, key):
 if len(self.contents) > 1:
 raise exc.InvalidRequestError(
 'Multiple classes found for path "%s" '
 "in the registry of this declarative "
 "base. Please use a fully module-qualified path."
 % (".".join(path + [key]))
 )
 else:
 ref = list(self.contents)[0]
 cls = ref()
 if cls is None:
 raise NameError(key)
 return cls
 
 def _remove_item(self, ref):
 self.contents.remove(ref)
 if not self.contents:
 _registries.discard(self)
 if self.on_remove:
 self.on_remove()
 
 def add_item(self, item):
 # protect against class registration race condition against
 # asynchronous garbage collection calling _remove_item,
 # [ticket:3208]
 modules = set(
 [
 cls.__module__
 for cls in [ref() for ref in self.contents]
 if cls is not None
 ]
 )
 if item.__module__ in modules:
 util.warn(
 "This declarative base already contains a class with the "
 "same class name and module name as %s.%s, and will "
 "be replaced in the string-lookup table."
 % (item.__module__, item.__name__)
 )
 self.contents.add(weakref.ref(item, self._remove_item))
 
 
 class _ModuleMarker(object):
 """Refers to a module name within
 _decl_class_registry.
 
 """
 
 __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"
 
 def __init__(self, name, parent):
 self.parent = parent
 self.name = name
 self.contents = {}
 self.mod_ns = _ModNS(self)
 if self.parent:
 self.path = self.parent.path + [self.name]
 else:
 self.path = []
 _registries.add(self)
 
 def __contains__(self, name):
 return name in self.contents
 
 def __getitem__(self, name):
 return self.contents[name]
 
 def _remove_item(self, name):
 self.contents.pop(name, None)
 if not self.contents and self.parent is not None:
 self.parent._remove_item(self.name)
 _registries.discard(self)
 
 def resolve_attr(self, key):
 return getattr(self.mod_ns, key)
 
 def get_module(self, name):
 if name not in self.contents:
 marker = _ModuleMarker(name, self)
 self.contents[name] = marker
 else:
 marker = self.contents[name]
 return marker
 
 def add_class(self, name, cls):
 if name in self.contents:
 existing = self.contents[name]
 existing.add_item(cls)
 else:
 existing = self.contents[name] = _MultipleClassMarker(
 [cls], on_remove=lambda: self._remove_item(name)
 )
 
 
 class _ModNS(object):
 __slots__ = ("__parent",)
 
 def __init__(self, parent):
 self.__parent = parent
 
 def __getattr__(self, key):
 try:
 value = self.__parent.contents[key]
 except KeyError:
 pass
 else:
 if value is not None:
 if isinstance(value, _ModuleMarker):
 return value.mod_ns
 else:
 assert isinstance(value, _MultipleClassMarker)
 return value.attempt_get(self.__parent.path, key)
 raise AttributeError(
 "Module %r has no mapped classes "
 "registered under the name %r" % (self.__parent.name, key)
 )
 
 
 class _GetColumns(object):
 __slots__ = ("cls",)
 
 def __init__(self, cls):
 self.cls = cls
 
 def __getattr__(self, key):
 mp = class_mapper(self.cls, configure=False)
 if mp:
 if key not in mp.all_orm_descriptors:
 raise exc.InvalidRequestError(
 "Class %r does not have a mapped column named %r"
 % (self.cls, key)
 )
 
 desc = mp.all_orm_descriptors[key]
 if desc.extension_type is interfaces.NOT_EXTENSION:
 prop = desc.property
 if isinstance(prop, SynonymProperty):
 key = prop.name
 elif not isinstance(prop, ColumnProperty):
 raise exc.InvalidRequestError(
 "Property %r is not an instance of"
 " ColumnProperty (i.e. does not correspond"
 " directly to a Column)." % key
 )
 return getattr(self.cls, key)
 
 
 inspection._inspects(_GetColumns)(
 lambda target: inspection.inspect(target.cls)
 )
 
 
 class _GetTable(object):
 __slots__ = "key", "metadata"
 
 def __init__(self, key, metadata):
 self.key = key
 self.metadata = metadata
 
 def __getattr__(self, key):
 return self.metadata.tables[_get_table_key(key, self.key)]
 
 
 def _determine_container(key, value):
 if isinstance(value, _MultipleClassMarker):
 value = value.attempt_get([], key)
 return _GetColumns(value)
 
 
 class _class_resolver(object):
 def __init__(self, cls, prop, fallback, arg, favor_tables=False):
 self.cls = cls
 self.prop = prop
 self.arg = self._declarative_arg = arg
 self.fallback = fallback
 self._dict = util.PopulateDict(self._access_cls)
 self._resolvers = ()
 self.favor_tables = favor_tables
 
 def _access_cls(self, key):
 cls = self.cls
 
 if self.favor_tables:
 if key in cls.metadata.tables:
 return cls.metadata.tables[key]
 elif key in cls.metadata._schemas:
 return _GetTable(key, cls.metadata)
 
 if key in cls._decl_class_registry:
 return _determine_container(key, cls._decl_class_registry[key])
 
 if not self.favor_tables:
 if key in cls.metadata.tables:
 return cls.metadata.tables[key]
 elif key in cls.metadata._schemas:
 return _GetTable(key, cls.metadata)
 
 if (
 "_sa_module_registry" in cls._decl_class_registry
 and key in cls._decl_class_registry["_sa_module_registry"]
 ):
 registry = cls._decl_class_registry["_sa_module_registry"]
 return registry.resolve_attr(key)
 elif self._resolvers:
 for resolv in self._resolvers:
 value = resolv(key)
 if value is not None:
 return value
 
 return self.fallback[key]
 
 def _raise_for_name(self, name, err):
 util.raise_(
 exc.InvalidRequestError(
 "When initializing mapper %s, expression %r failed to "
 "locate a name (%r). If this is a class name, consider "
 "adding this relationship() to the %r class after "
 "both dependent classes have been defined."
 % (self.prop.parent, self.arg, name, self.cls)
 ),
 from_=err,
 )
 
 def _resolve_name(self):
 name = self.arg
 d = self._dict
 rval = None
 try:
 for token in name.split("."):
 if rval is None:
 rval = d[token]
 else:
 rval = getattr(rval, token)
 except KeyError as err:
 self._raise_for_name(name, err)
 except NameError as n:
 self._raise_for_name(n.args[0], n)
 else:
 if isinstance(rval, _GetColumns):
 return rval.cls
 else:
 return rval
 
 def __call__(self):
 try:
 x = eval(self.arg, globals(), self._dict)
 
 if isinstance(x, _GetColumns):
 return x.cls
 else:
 return x
 except NameError as n:
 self._raise_for_name(n.args[0], n)
 
 
 def _resolver(cls, prop):
 import sqlalchemy
 from sqlalchemy.orm import foreign, remote
 
 fallback = sqlalchemy.__dict__.copy()
 fallback.update({"foreign": foreign, "remote": remote})
 
 def resolve_arg(arg, favor_tables=False):
 return _class_resolver(
 cls, prop, fallback, arg, favor_tables=favor_tables
 )
 
 def resolve_name(arg):
 return _class_resolver(cls, prop, fallback, arg)._resolve_name
 
 return resolve_name, resolve_arg
 
 
 def _deferred_relationship(cls, prop):
 
 if isinstance(prop, RelationshipProperty):
 resolve_name, resolve_arg = _resolver(cls, prop)
 
 for attr in (
 "order_by",
 "primaryjoin",
 "secondaryjoin",
 "secondary",
 "_user_defined_foreign_keys",
 "remote_side",
 ):
 v = getattr(prop, attr)
 if isinstance(v, util.string_types):
 setattr(
 prop,
 attr,
 resolve_arg(v, favor_tables=attr == "secondary"),
 )
 
 for attr in ("argument",):
 v = getattr(prop, attr)
 if isinstance(v, util.string_types):
 setattr(prop, attr, resolve_name(v))
 
 if prop.backref and isinstance(prop.backref, tuple):
 key, kwargs = prop.backref
 for attr in (
 "primaryjoin",
 "secondaryjoin",
 "secondary",
 "foreign_keys",
 "remote_side",
 "order_by",
 ):
 if attr in kwargs and isinstance(
 kwargs[attr], util.string_types
 ):
 kwargs[attr] = resolve_arg(kwargs[attr])
 
 return prop
 
 |