# ext/declarative/clsregistry.py
# Copyright (C) 2005-2020 the SQLAlchemy authors and contributors
#
#
# This module is part of SQLAlchemy and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
"""Routines to handle the string class registry used by declarative.

This system allows specification of classes and expressions used in
:func:`.relationship` using strings.

"""
import weakref

from ... import exc
from ... import inspection
from ... import util
from ...orm import class_mapper
from ...orm import interfaces
from ...orm.properties import ColumnProperty
from ...orm.properties import RelationshipProperty
from ...orm.properties import SynonymProperty
from ...schema import _get_table_key

# strong references to registries which we place in
# the _decl_class_registry, which is usually weak referencing.
# the internal registries here link to classes with weakrefs and remove
# themselves when all references to contained classes are removed.
_registries = set()


def add_class(classname, cls):
    """Add a class to the _decl_class_registry associated with the
    given declarative class.

    :param classname: name under which ``cls`` is registered.
    :param cls: the declarative class; its ``_decl_class_registry``
     mapping and ``__module__`` attribute are used.

    The class is stored both under its plain name and, below the
    ``"_sa_module_registry"`` root marker, under every suffix of its
    dotted module path.

    """
    if classname in cls._decl_class_registry:
        # class already exists.
        existing = cls._decl_class_registry[classname]
        if not isinstance(existing, _MultipleClassMarker):
            # second class with this name: swap the plain class entry for
            # a marker that tracks both classes.
            existing = cls._decl_class_registry[
                classname
            ] = _MultipleClassMarker([cls, existing])
    else:
        cls._decl_class_registry[classname] = cls

    try:
        root_module = cls._decl_class_registry["_sa_module_registry"]
    except KeyError:
        # first registration against this registry; create the root marker.
        cls._decl_class_registry[
            "_sa_module_registry"
        ] = root_module = _ModuleMarker("_sa_module_registry", None)

    tokens = cls.__module__.split(".")

    # build up a tree like this:
    # modulename:  myapp.snacks.nuts
    #
    # myapp->snacks->nuts->(classes)
    # snacks->nuts->(classes)
    # nuts->(classes)
    #
    # this allows partial token paths to be used.
    while tokens:
        token = tokens.pop(0)
        module = root_module.get_module(token)
        for token in tokens:
            module = module.get_module(token)
        module.add_class(classname, cls)


class _MultipleClassMarker(object):
    """refers to multiple classes of the same name
    within _decl_class_registry.

    Classes are held through weak references; when the last reference is
    removed the marker drops itself from ``_registries`` and invokes the
    optional ``on_remove`` callable.

    """

    __slots__ = "on_remove", "contents", "__weakref__"

    def __init__(self, classes, on_remove=None):
        """Track ``classes`` weakly.

        :param classes: iterable of classes to reference.
        :param on_remove: optional zero-argument callable invoked once the
         marker no longer references any class.
        """
        self.on_remove = on_remove
        self.contents = set(
            [weakref.ref(item, self._remove_item) for item in classes]
        )
        _registries.add(self)

    def __iter__(self):
        """Iterate over the referents of the stored weak references."""
        return (ref() for ref in self.contents)

    def attempt_get(self, path, key):
        """Return the single class referenced by this marker.

        :param path: list of module tokens, used only for the error message.
        :param key: the class name being looked up.
        :raises InvalidRequestError: if more than one class is referenced.
        :raises NameError: if no live class is available.
        """
        # snapshot the set once so the length check and the element access
        # below see the same contents.
        refs = list(self.contents)
        if len(refs) > 1:
            raise exc.InvalidRequestError(
                'Multiple classes found for path "%s" '
                "in the registry of this declarative "
                "base. Please use a fully module-qualified path."
                % (".".join(path + [key]))
            )
        if not refs:
            # an emptied marker is reported like a dead reference instead
            # of failing with IndexError.
            raise NameError(key)
        cls = refs[0]()
        if cls is None:
            raise NameError(key)
        return cls

    def _remove_item(self, ref):
        """Weakref callback: forget ``ref`` and clean up when empty."""
        # discard() rather than remove(): a missing ref must not raise
        # KeyError out of a weakref callback.
        self.contents.discard(ref)
        if not self.contents:
            _registries.discard(self)
            if self.on_remove:
                self.on_remove()

    def add_item(self, item):
        """Add ``item`` to this marker, warning when a class with the same
        module name is already present.
        """
        # protect against class registration race condition against
        # asynchronous garbage collection calling _remove_item,
        # [ticket:3208]; iterate over a snapshot so that a concurrent
        # removal cannot change the set's size during iteration.
        modules = set(
            [
                cls.__module__
                for cls in [ref() for ref in list(self.contents)]
                if cls is not None
            ]
        )
        if item.__module__ in modules:
            util.warn(
                "This declarative base already contains a class with the "
                "same class name and module name as %s.%s, and will "
                "be replaced in the string-lookup table."
                % (item.__module__, item.__name__)
            )
        self.contents.add(weakref.ref(item, self._remove_item))


class _ModuleMarker(object):
    """refers to a module name within
    _decl_class_registry.

    ``contents`` maps names to child ``_ModuleMarker`` objects (created by
    :meth:`get_module`) or ``_MultipleClassMarker`` objects (created by
    :meth:`add_class`).

    """

    __slots__ = "parent", "name", "contents", "mod_ns", "path", "__weakref__"

    def __init__(self, name, parent):
        """Create a marker called ``name`` beneath ``parent``.

        A marker without a parent gets an empty ``path``; any other marker
        extends its parent's ``path`` with its own name.
        """
        self.parent = parent
        self.name = name
        self.contents = {}
        self.mod_ns = _ModNS(self)
        if self.parent:
            self.path = self.parent.path + [self.name]
        else:
            self.path = []
        _registries.add(self)

    def __contains__(self, name):
        return name in self.contents

    def __getitem__(self, name):
        return self.contents[name]

    def _remove_item(self, name):
        """Drop ``name``; when this marker becomes empty and has a parent,
        remove it from the parent and from ``_registries``.
        """
        self.contents.pop(name, None)
        if not self.contents and self.parent is not None:
            self.parent._remove_item(self.name)
            _registries.discard(self)

    def resolve_attr(self, key):
        """Look up ``key`` through this marker's ``_ModNS`` namespace."""
        return getattr(self.mod_ns, key)

    def get_module(self, name):
        """Return the child marker for ``name``, creating it if needed."""
        if name not in self.contents:
            marker = _ModuleMarker(name, self)
            self.contents[name] = marker
        else:
            marker = self.contents[name]
        return marker

    def add_class(self, name, cls):
        """Register ``cls`` under ``name`` within this marker."""
        if name in self.contents:
            existing = self.contents[name]
            existing.add_item(cls)
        else:
            # the callback removes this entry once the class marker empties.
            existing = self.contents[name] = _MultipleClassMarker(
                [cls], on_remove=lambda: self._remove_item(name)
            )


class _ModNS(object):
    """Attribute-access namespace over the contents of a
    ``_ModuleMarker``.
    """

    __slots__ = ("__parent",)

    def __init__(self, parent):
        self.__parent = parent

    def __getattr__(self, key):
        """Return the nested namespace or the class registered as ``key``.

        :raises AttributeError: if nothing usable is registered under
         ``key``.
        """
        try:
            value = self.__parent.contents[key]
        except KeyError:
            pass
        else:
            if value is not None:
                if isinstance(value, _ModuleMarker):
                    return value.mod_ns
                else:
                    assert isinstance(value, _MultipleClassMarker)
                    return value.attempt_get(self.__parent.path, key)
        raise AttributeError(
            "Module %r has no mapped classes "
            "registered under the name %r" % (self.__parent.name, key)
        )


class _GetColumns(object):
    """Attribute-access wrapper around a mapped class, validating the
    requested attribute against the class's mapper before returning it.
    """

    __slots__ = ("cls",)

    def __init__(self, cls):
        self.cls = cls

    def __getattr__(self, key):
        """Return the attribute ``key`` of the wrapped class.

        :raises InvalidRequestError: if the mapper has no descriptor named
         ``key``, or if the descriptor's property is neither a
         ``ColumnProperty`` nor a ``SynonymProperty``.
        """
        mp = class_mapper(self.cls, configure=False)
        if mp:
            if key not in mp.all_orm_descriptors:
                raise exc.InvalidRequestError(
                    "Class %r does not have a mapped column named %r"
                    % (self.cls, key)
                )

            desc = mp.all_orm_descriptors[key]
            if desc.extension_type is interfaces.NOT_EXTENSION:
                prop = desc.property
                if isinstance(prop, SynonymProperty):
                    # a synonym is looked up under the name it refers to.
                    key = prop.name
                elif not isinstance(prop, ColumnProperty):
                    raise exc.InvalidRequestError(
                        "Property %r is not an instance of"
                        " ColumnProperty (i.e. does not correspond"
                        " directly to a Column)." % key
                    )
        return getattr(self.cls, key)


# inspecting a _GetColumns inspects the class it wraps.
inspection._inspects(_GetColumns)(
    lambda target: inspection.inspect(target.cls)
)


class _GetTable(object):
    """Attribute-access wrapper that looks up tables in
    ``metadata.tables``, combining the attribute name with the stored
    ``key`` via ``_get_table_key``.
    """

    __slots__ = "key", "metadata"

    def __init__(self, key, metadata):
        self.key = key
        self.metadata = metadata

    def __getattr__(self, key):
        return self.metadata.tables[_get_table_key(key, self.key)]


def _determine_container(key, value):
    """Wrap a registry entry in ``_GetColumns``, first resolving a
    ``_MultipleClassMarker`` to its single class.
    """
    if isinstance(value, _MultipleClassMarker):
        value = value.attempt_get([], key)
    return _GetColumns(value)


class _class_resolver(object):
    """Callable that evaluates a string argument against the class
    registry, the metadata's tables and schemas, the module registry, any
    extra resolvers, and a fallback namespace.
    """

    def __init__(self, cls, prop, fallback, arg):
        self.cls = cls
        self.prop = prop
        self.arg = self._declarative_arg = arg
        self.fallback = fallback
        # namespace handed to eval() as locals.
        # NOTE(review): presumably PopulateDict calls _access_cls for keys
        # that are not yet present -- confirm against util.PopulateDict.
        self._dict = util.PopulateDict(self._access_cls)
        self._resolvers = ()

    def _access_cls(self, key):
        """Resolve a single name used within the string expression.

        Lookup order: class registry, ``metadata.tables``,
        ``metadata._schemas``, the ``"_sa_module_registry"`` marker, the
        callables in ``self._resolvers``, and finally ``self.fallback``.
        """
        cls = self.cls
        if key in cls._decl_class_registry:
            return _determine_container(key, cls._decl_class_registry[key])
        elif key in cls.metadata.tables:
            return cls.metadata.tables[key]
        elif key in cls.metadata._schemas:
            return _GetTable(key, cls.metadata)
        elif (
            "_sa_module_registry" in cls._decl_class_registry
            and key in cls._decl_class_registry["_sa_module_registry"]
        ):
            registry = cls._decl_class_registry["_sa_module_registry"]
            return registry.resolve_attr(key)
        elif self._resolvers:
            for resolv in self._resolvers:
                value = resolv(key)
                if value is not None:
                    return value

        return self.fallback[key]

    def __call__(self):
        """Evaluate the stored string and return the result.

        A ``_GetColumns`` result is unwrapped to the class it holds.

        :raises InvalidRequestError: if the expression refers to a name
         that cannot be located.
        """
        try:
            # NOTE: the argument is evaluated with eval(); it must never
            # originate from untrusted input.
            x = eval(self.arg, globals(), self._dict)

            if isinstance(x, _GetColumns):
                return x.cls
            else:
                return x
        except NameError as n:
            raise exc.InvalidRequestError(
                "When initializing mapper %s, expression %r failed to "
                "locate a name (%r). If this is a class name, consider "
                "adding this relationship() to the %r class after "
                "both dependent classes have been defined."
                % (self.prop.parent, self.arg, n.args[0], self.cls)
            )


def _resolver(cls, prop):
    """Return a function that turns a string argument into a
    ``_class_resolver`` bound to ``cls`` and ``prop``.

    The fallback namespace is a copy of ``sqlalchemy.__dict__`` extended
    with ``foreign`` and ``remote``.
    """
    import sqlalchemy
    from sqlalchemy.orm import foreign, remote

    fallback = sqlalchemy.__dict__.copy()
    fallback.update({"foreign": foreign, "remote": remote})

    def resolve_arg(arg):
        return _class_resolver(cls, prop, fallback, arg)

    return resolve_arg


def _deferred_relationship(cls, prop):
    """Replace string-valued arguments of a ``RelationshipProperty`` (and
    of its tuple-form backref, if any) with deferred resolvers.

    Properties of any other type are returned untouched.

    :return: the same ``prop`` object.
    """
    if isinstance(prop, RelationshipProperty):
        resolve_arg = _resolver(cls, prop)

        for attr in (
            "argument",
            "order_by",
            "primaryjoin",
            "secondaryjoin",
            "secondary",
            "_user_defined_foreign_keys",
            "remote_side",
        ):
            v = getattr(prop, attr)
            if isinstance(v, util.string_types):
                setattr(prop, attr, resolve_arg(v))

        if prop.backref and isinstance(prop.backref, tuple):
            key, kwargs = prop.backref
            for attr in (
                "primaryjoin",
                "secondaryjoin",
                "secondary",
                "foreign_keys",
                "remote_side",
                "order_by",
            ):
                if attr in kwargs and isinstance(
                    kwargs[attr], util.string_types
                ):
                    kwargs[attr] = resolve_arg(kwargs[attr])

    return prop