# Source code for sqlobject.joins

from itertools import count
from . import boundattributes
from . import classregistry
from . import events
from . import sresults
from . import styles
from . import sqlbuilder
from .styles import capword

__all__ = [
    'MultipleJoin', 'SQLMultipleJoin',
    'RelatedJoin', 'SQLRelatedJoin',
    'SingleJoin',
    'ManyToMany', 'OneToMany',
]

# Module-wide counter; every Join takes the next value as its creationOrder.
creationOrder = count()
# Sentinel meaning "no explicit value given" -- e.g. SOJoin's orderBy falls
# back to the other class's sqlmeta.defaultOrder while it is still NoDefault.
NoDefault = sqlbuilder.NoDefault


def getID(obj):
    """Return the primary key for *obj*.

    Objects exposing an ``id`` attribute (SQLObject instances) yield that
    value; anything else is treated as an id already and coerced with
    ``int``.
    """
    missing = object()
    found = getattr(obj, 'id', missing)
    if found is missing:
        return int(obj)
    return found


class Join(object):
    """Declarative join specification attached to an SQLObject class body.

    Stores ``otherClass`` plus any extra keyword arguments in ``self.kw``;
    ``withClass()`` later instantiates ``self.baseClass`` (set by subclasses
    such as MultipleJoin) with them to build the bound SO* join object.
    """

    def __init__(self, otherClass=None, **kw):
        kw['otherClass'] = otherClass
        self.kw = kw
        # joinMethodName is kept separately so it can be set/validated later
        # through the joinMethodName property.
        self._joinMethodName = self.kw.pop('joinMethodName', None)
        self.creationOrder = next(creationOrder)

    def _set_joinMethodName(self, value):
        """Set the join method name unless a different one is already set.

        Raises:
            AssertionError: if an explicit, different name was given before.
                Raised explicitly (not via ``assert``) so the check is not
                stripped when Python runs with ``-O``.
        """
        if not (self._joinMethodName == value
                or self._joinMethodName is None):
            raise AssertionError(
                "You have already given an explicit joinMethodName (%s), "
                "and you are now setting it to %s"
                % (self._joinMethodName, value))
        self._joinMethodName = value

    def _get_joinMethodName(self):
        return self._joinMethodName

    joinMethodName = property(_get_joinMethodName, _set_joinMethodName)
    name = joinMethodName

    def withClass(self, soClass):
        """Instantiate ``self.baseClass`` bound to *soClass* and return it."""
        # __init__ already pops 'joinMethodName' from kw; this only matters
        # if someone re-inserted it into self.kw afterwards.
        if 'joinMethodName' in self.kw:
            self._joinMethodName = self.kw['joinMethodName']
            del self.kw['joinMethodName']
        return self.baseClass(creationOrder=self.creationOrder,
                              soClass=soClass,
                              joinDef=self,
                              joinMethodName=self._joinMethodName,
                              **self.kw)


# A join is separate from a foreign key, i.e., it is
# many-to-many, or one-to-many where the *other* class
# has the foreign key.


class SOJoin(object):
    """Base class for joins bound to a concrete SQLObject class.

    ``otherClass`` arrives as a class *name*; the resolved class is stored
    on ``self.otherClass`` by ``_setOtherClass``, which is registered as a
    callback with the class registry of ``soClass``.
    """

    def __init__(self,
                 creationOrder,
                 soClass=None,
                 otherClass=None,
                 joinColumn=None,
                 joinMethodName=None,
                 orderBy=NoDefault,
                 joinDef=None):
        self.creationOrder = creationOrder
        self.soClass = soClass
        self.joinDef = joinDef
        self.otherClassName = otherClass
        registry = classregistry.registry(soClass.sqlmeta.registry)
        registry.addClassCallback(otherClass, self._setOtherClass)
        self.joinColumn = joinColumn
        self.joinMethodName = joinMethodName
        self._orderBy = orderBy
        if not self.joinColumn:
            # Default to the basic one-to-many layout: the other table holds
            # a column referencing this class's table.
            style = styles.getStyle(self.soClass)
            self.joinColumn = style.tableReference(self.soClass.sqlmeta.table)

    @property
    def orderBy(self):
        """Ordering for join results; lazily defaults to otherClass's."""
        if self._orderBy is NoDefault:
            self._orderBy = self.otherClass.sqlmeta.defaultOrder
        return self._orderBy

    def _setOtherClass(self, cls):
        self.otherClass = cls

    def hasIntermediateTable(self):
        return False

    def _applyOrderBy(self, results, defaultSortClass):
        """Sort *results* in place per ``self.orderBy`` (if any) and return it."""
        order = self.orderBy
        if order is not None:
            doSort(results, order)
        return results


class MinType(object):
    """Sentinel that orders below every other value.

    doSort() substitutes it for None attribute values so rows with missing
    values can be sorted alongside real ones.
    """

    def __lt__(self, other):
        return self is not other

    def __le__(self, other):
        return True

    def __eq__(self, other):
        return self is other

    def __ge__(self, other):
        return self is other

    def __gt__(self, other):
        return False


# Shared singleton instance used as the sort key for None.
Min = MinType()


def doSort(results, orderBy):
    """Sort the list *results* in place according to *orderBy*.

    *orderBy* may be an attribute name (prefix ``-`` for descending), an
    ``sqlbuilder.SQLObjectField``, a ``sqlbuilder.DESC`` wrapping such a
    field, or a tuple/list of these meaning "order by the first, then by
    the next, ...". ``None`` attribute values sort before everything else
    (after everything when descending).
    """
    if isinstance(orderBy, (tuple, list)):
        # list.sort() is stable, so a multi-key ordering is built by sorting
        # on the least significant key first and the most significant
        # (orderBy[0]) last. An empty sequence leaves results untouched.
        for key in reversed(orderBy):
            doSort(results, key)
        return
    if isinstance(orderBy, sqlbuilder.DESC) \
       and isinstance(orderBy.expr, sqlbuilder.SQLObjectField):
        orderBy = '-' + orderBy.expr.original
    elif isinstance(orderBy, sqlbuilder.SQLObjectField):
        orderBy = orderBy.original
    # @@: but we don't handle more complex expressions for orderings
    if orderBy.startswith('-'):
        attrName = orderBy[1:]
        reverse = True
    else:
        attrName = orderBy
        reverse = False

    def sortkey(x, attr=attrName):
        value = getattr(x, attr)
        if value is None:
            return Min
        return value
    results.sort(key=sortkey, reverse=reverse)


# This is a one-to-many


class SOMultipleJoin(SOJoin):
    """One-to-many join: rows of otherClass reference this class.

    Args:
        addRemoveName: name used for the add/remove accessors; defaults to
            ``capword(otherClassName)``.
        **kw: forwarded unchanged to ``SOJoin.__init__``.
    """

    def __init__(self, addRemoveName=None, **kw):
        SOJoin.__init__(self, **kw)

        # Default accessor name: lower-case the first letter of the other
        # class name and pluralise naively ("s", or "es" after a final "s").
        if not self.joinMethodName:
            other = self.otherClassName
            base = other[0].lower() + other[1:]
            suffix = "es" if base.endswith('s') else "s"
            self.joinMethodName = base + suffix
        self.addRemoveName = addRemoveName or capword(self.otherClassName)

    def performJoin(self, inst):
        """Return the related otherClass instances of *inst* as a sorted list."""
        rows = inst._connection._SO_selectJoin(
            self.otherClass, self.joinColumn, inst.id)
        conn = inst._connection if inst.sqlmeta._perConnection else None
        related = [self.otherClass.get(rowId, conn)
                   for (rowId,) in rows if rowId is not None]
        return self._applyOrderBy(related, self.otherClass)

    def _dbNameToPythonName(self):
        """Map ``self.joinColumn`` (a DB column name) to its Python attribute."""
        for column in self.otherClass.sqlmeta.columns.values():
            if column.dbName == self.joinColumn:
                break
        else:
            style = self.soClass.sqlmeta.style
            return style.dbColumnToPythonAttr(self.joinColumn)
        return column.name


class MultipleJoin(Join):
    """Declarative one-to-many join; withClass() builds an SOMultipleJoin."""
    baseClass = SOMultipleJoin
class SOSQLMultipleJoin(SOMultipleJoin):
    """One-to-many join that returns a select result instead of a list."""

    def performJoin(self, inst):
        """Select otherClass rows whose back-reference equals ``inst.id``."""
        conn = inst._connection if inst.sqlmeta._perConnection else None
        attrName = self._dbNameToPythonName()
        backReference = getattr(self.otherClass.q, attrName)
        selection = self.otherClass.select(backReference == inst.id,
                                           connection=conn)
        return selection.orderBy(self.orderBy)
class SQLMultipleJoin(Join):
    """Declarative one-to-many join; withClass() builds an SOSQLMultipleJoin."""
    baseClass = SOSQLMultipleJoin
# Many-to-many join stored through an intermediate (link) table.
class SORelatedJoin(SOMultipleJoin):
    """Many-to-many join via an intermediate table.

    If not given, the intermediate table name is the two tables' names,
    sorted and joined with ``_``, and ``otherColumn`` is the style's
    reference column for the other table; both are filled in once the
    other class is resolved.
    """

    def __init__(self,
                 otherColumn=None,
                 intermediateTable=None,
                 createRelatedTable=True,
                 **kw):
        self.intermediateTable = intermediateTable
        self.otherColumn = otherColumn
        self.createRelatedTable = createRelatedTable
        SOMultipleJoin.__init__(self, **kw)
        registry = classregistry.registry(self.soClass.sqlmeta.registry)
        registry.addClassCallback(self.otherClassName,
                                  self._setOtherRelatedClass)

    def _setOtherRelatedClass(self, otherClass):
        if not self.intermediateTable:
            first, second = sorted([self.soClass.sqlmeta.table,
                                    otherClass.sqlmeta.table])
            self.intermediateTable = '%s_%s' % (first, second)
        if not self.otherColumn:
            style = self.soClass.sqlmeta.style
            self.otherColumn = style.tableReference(otherClass.sqlmeta.table)

    def hasIntermediateTable(self):
        return True

    def performJoin(self, inst):
        """Return otherClass instances linked to *inst*, sorted per orderBy."""
        rows = inst._connection._SO_intermediateJoin(
            self.intermediateTable,
            self.otherColumn,
            self.joinColumn,
            inst.id)
        conn = inst._connection if inst.sqlmeta._perConnection else None
        related = [self.otherClass.get(rowId, conn)
                   for (rowId,) in rows if rowId is not None]
        return self._applyOrderBy(related, self.otherClass)

    def remove(self, inst, other):
        """Delete the link row between *inst* and *other*."""
        inst._connection._SO_intermediateDelete(
            self.intermediateTable,
            self.joinColumn, getID(inst),
            self.otherColumn, getID(other))

    def add(self, inst, other):
        """Insert a link row between *inst* and *other*."""
        inst._connection._SO_intermediateInsert(
            self.intermediateTable,
            self.joinColumn, getID(inst),
            self.otherColumn, getID(other))
class RelatedJoin(MultipleJoin):
    """Declarative many-to-many join; withClass() builds an SORelatedJoin."""
    baseClass = SORelatedJoin
# Helper SQL expressions used by SOSQLRelatedJoin.performJoin().
class OtherTableToJoin(sqlbuilder.SQLExpression):
    """Renders ``<other>.<otherIdName> = <interTable>.<joinColumn>``.

    ``<other>`` is *alias* when given, otherwise *otherTable*.
    """

    def __init__(self, otherTable, otherIdName, interTable, joinColumn, alias):
        self.otherTable = otherTable
        self.otherIdName = otherIdName
        self.interTable = interTable
        self.joinColumn = joinColumn
        self.alias = alias

    def tablesUsedImmediate(self):
        return [
            '%s %s' % (self.otherTable, self.alias)
            if self.alias else self.otherTable,
            self.interTable,
        ]

    def __sqlrepr__(self, db):
        return '%s.%s = %s.%s' % (
            self.alias if self.alias else self.otherTable,
            self.otherIdName,
            self.interTable,
            self.joinColumn)


class JoinToTable(sqlbuilder.SQLExpression):
    """Renders ``<interTable>.<joinColumn> = <table>.<idName>``.

    ``<table>`` is *alias* when given, otherwise *table*.
    """

    def __init__(self, table, idName, interTable, joinColumn, alias):
        self.table = table
        self.idName = idName
        self.interTable = interTable
        self.joinColumn = joinColumn
        self.alias = alias

    def tablesUsedImmediate(self):
        return [
            '%s %s' % (self.table, self.alias)
            if self.alias else self.table,
            self.interTable,
        ]

    def __sqlrepr__(self, db):
        return '%s.%s = %s.%s' % (
            self.interTable,
            self.joinColumn,
            self.alias if self.alias else self.table,
            self.idName)


class TableToId(sqlbuilder.SQLExpression):
    """Renders ``<table>.<idName> = <idValue>``.

    ``<table>`` is *alias* when given, otherwise *table*.
    """

    def __init__(self, table, idName, idValue, alias):
        self.table = table
        self.idName = idName
        self.idValue = idValue
        self.alias = alias

    def tablesUsedImmediate(self):
        return [
            '%s %s' % (self.table, self.alias)
            if self.alias else self.table,
        ]

    def __sqlrepr__(self, db):
        # Render the id through sqlrepr so non-integer primary keys (e.g.
        # string ids) are quoted/escaped for *db*; integer ids render the
        # same text as plain %s interpolation did.
        return '%s.%s = %s' % (
            self.alias if self.alias else self.table,
            self.idName,
            sqlbuilder.sqlrepr(self.idValue, db))


class SQLJoinSelectResults(sresults.SelectResults):
    """SelectResults used for self-referential SQLRelatedJoins.

    Rejects filters that mention the source table without an alias, since
    that table appears twice (aliased) in the generated query.
    """

    def filter(self, filter_clause):
        clause_tables = filter_clause.tablesUsed(None)
        if self._SOSQLRelatedJoin_realSourceClass.sqlmeta.table \
                in clause_tables:
            tableClass = self._SOSQLRelatedJoin_realSourceClass.__name__
            raise ValueError(
                "Using table '%s' in the filter expression without an alias "
                "could produce wrong SQL. Most probably you need "
                "Alias(%s, '_SO_SQLRelatedJoin_OtherTable') instead."
                % (tableClass, tableClass)
            )
        return sresults.SelectResults.filter(self, filter_clause)


class SOSQLRelatedJoin(SORelatedJoin):
    """Many-to-many join that returns a lazy select result."""

    def performJoin(self, inst):
        """Build a select over otherClass joined through the link table.

        When the join is self-referential (soClass is otherClass), both
        sides are aliased so the same table can appear twice in the query.
        """
        if inst.sqlmeta._perConnection:
            conn = inst._connection
        else:
            conn = None
        needAlias = self.soClass is self.otherClass
        if needAlias:
            source = sqlbuilder.Alias(
                self.otherClass, '_SO_SQLRelatedJoin_OtherTable')
            sresultsClass = SQLJoinSelectResults
        else:
            source = self.otherClass
            sresultsClass = self.otherClass.SelectResultsClass
        results = sresultsClass(
            source,
            sqlbuilder.AND(
                OtherTableToJoin(
                    self.otherClass.sqlmeta.table,
                    self.otherClass.sqlmeta.idName,
                    self.intermediateTable,
                    self.otherColumn,
                    '_SO_SQLRelatedJoin_OtherTable' if needAlias else '',
                ),
                JoinToTable(
                    self.soClass.sqlmeta.table,
                    self.soClass.sqlmeta.idName,
                    self.intermediateTable,
                    self.joinColumn,
                    '_SO_SQLRelatedJoin_ThisTable' if needAlias else '',
                ),
                TableToId(
                    self.soClass.sqlmeta.table,
                    self.soClass.sqlmeta.idName,
                    inst.id,
                    '_SO_SQLRelatedJoin_ThisTable' if needAlias else '',
                ),
            ),
            clauseTables=(
                '%s _SO_SQLRelatedJoin_ThisTable'
                % self.soClass.sqlmeta.table
                if needAlias else self.soClass.sqlmeta.table,
                '%s _SO_SQLRelatedJoin_OtherTable'
                % self.otherClass.sqlmeta.table
                if needAlias else self.otherClass.sqlmeta.table,
                self.intermediateTable,
            ),
            connection=conn,
            orderBy=self.orderBy,
        )
        # Read back by SQLJoinSelectResults.filter() to detect misuse.
        results._SOSQLRelatedJoin_realSourceClass = self.otherClass
        return results
class SQLRelatedJoin(RelatedJoin):
    """Declarative many-to-many join; withClass() builds an SOSQLRelatedJoin."""
    baseClass = SOSQLRelatedJoin
class SOSingleJoin(SOMultipleJoin):
    """Join returning a single related object (or None).

    With ``makeDefault=True``, a missing related row is created on access.
    """

    def __init__(self, **kw):
        self.makeDefault = kw.pop('makeDefault', False)
        SOMultipleJoin.__init__(self, **kw)

    def performJoin(self, inst):
        conn = inst._connection if inst.sqlmeta._perConnection else None
        attrName = self._dbNameToPythonName()
        matches = self.otherClass.select(
            getattr(self.otherClass.q, attrName) == inst.id,
            connection=conn
        )
        if matches.count() != 0:
            return matches[0]
        if not self.makeDefault:
            return None
        # No related row yet: instantiate otherClass with its back-reference
        # attribute (derived via the style's instanceIDAttrToAttr) set to inst.
        style = self.soClass.sqlmeta.style
        return self.otherClass(**{style.instanceIDAttrToAttr(attrName): inst})
class SingleJoin(Join):
    """Declarative single-object join; withClass() builds an SOSingleJoin."""
    baseClass = SOSingleJoin
class SOManyToMany(object):
    """Descriptor implementing a many-to-many relation through a link table.

    Both classes are resolved through class-registry callbacks; once both
    are known, ``_finishSet`` fills in defaults, subscribes to table
    creation signals, and builds ``self.clause``.
    """

    def __init__(self, soClass, name, join,
                 intermediateTable, joinColumn, otherColumn,
                 createJoinTable, **attrs):
        self.name = name
        self.intermediateTable = intermediateTable
        self.joinColumn = joinColumn
        self.otherColumn = otherColumn
        self.createJoinTable = createJoinTable
        self.soClass = self.otherClass = None
        for attrName, attrValue in attrs.items():
            setattr(self, attrName, attrValue)
        registry = classregistry.registry(soClass.sqlmeta.registry)
        registry.addClassCallback(join, self._setOtherClass)
        registry.addClassCallback(soClass.__name__, self._setThisClass)

    def _setThisClass(self, soClass):
        self.soClass = soClass
        if self.soClass and self.otherClass:
            self._finishSet()

    def _setOtherClass(self, otherClass):
        self.otherClass = otherClass
        if self.soClass and self.otherClass:
            self._finishSet()

    def _finishSet(self):
        if self.intermediateTable is None:
            first, second = sorted([self.soClass.sqlmeta.table,
                                    self.otherClass.sqlmeta.table])
            self.intermediateTable = '%s_%s' % (first, second)
        if not self.otherColumn:
            self.otherColumn = self.soClass.sqlmeta.style.tableReference(
                self.otherClass.sqlmeta.table)
        if not self.joinColumn:
            self.joinColumn = styles.getStyle(
                self.soClass).tableReference(self.soClass.sqlmeta.table)
        for cls in (self.soClass, self.otherClass):
            events.listen(self.event_CreateTableSignal, cls,
                          events.CreateTableSignal)
        toOther = (self.otherClass.q.id
                   == sqlbuilder.Field(self.intermediateTable,
                                       self.otherColumn))
        toThis = (sqlbuilder.Field(self.intermediateTable, self.joinColumn)
                  == self.soClass.q.id)
        self.clause = toOther & toThis

    def __get__(self, obj, type):
        if obj is None:
            return self
        toOther = (self.otherClass.q.id
                   == sqlbuilder.Field(self.intermediateTable,
                                       self.otherColumn))
        toObj = (sqlbuilder.Field(self.intermediateTable, self.joinColumn)
                 == obj.id)
        select = self.otherClass.select(toOther & toObj)
        return _ManyToManySelectWrapper(obj, self, select)

    def event_CreateTableSignal(self, soClass, connection, extra_sql,
                                post_funcs):
        # Defer link-table creation until after the main table is created.
        if self.createJoinTable:
            post_funcs.append(self.event_CreateTableSignalPost)

    def event_CreateTableSignalPost(self, soClass, connection):
        if not connection.tableExists(self.intermediateTable):
            connection._SO_createJoinTable(self)
class ManyToMany(boundattributes.BoundFactory):
    """Declarative many-to-many attribute backed by SOManyToMany.

    ``join`` (the other class name) is listed in ``__unpackargs__``,
    presumably so it can be passed positionally -- see BoundFactory.
    """
    factory_class = SOManyToMany
    __restrict_attributes__ = (
        'join', 'intermediateTable',
        'joinColumn', 'otherColumn', 'createJoinTable')
    __unpackargs__ = ('join',)

    # Default values:
    intermediateTable = None
    joinColumn = None
    otherColumn = None
    createJoinTable = True
class _ManyToManySelectWrapper(object): def __init__(self, forObject, join, select): self.forObject = forObject self.join = join self.select = select def __getattr__(self, attr): # @@: This passes through private variable access too... should it? # Also magic methods, like __str__ return getattr(self.select, attr) def __repr__(self): return '<%s for: %s>' % (self.__class__.__name__, repr(self.select)) def __str__(self): return str(self.select) def __iter__(self): return iter(self.select) def __getitem__(self, key): return self.select[key] def add(self, obj): obj._connection._SO_intermediateInsert( self.join.intermediateTable, self.join.joinColumn, getID(self.forObject), self.join.otherColumn, getID(obj)) def remove(self, obj): obj._connection._SO_intermediateDelete( self.join.intermediateTable, self.join.joinColumn, getID(self.forObject), self.join.otherColumn, getID(obj)) def create(self, **kw): obj = self.join.otherClass(**kw) self.add(obj) return obj class SOOneToMany(object): def __init__(self, soClass, name, join, joinColumn, **attrs): self.soClass = soClass self.name = name self.joinColumn = joinColumn for name, value in attrs.items(): setattr(self, name, value) classregistry.registry( soClass.sqlmeta.registry).addClassCallback( join, self._setOtherClass) def _setOtherClass(self, otherClass): self.otherClass = otherClass if not self.joinColumn: self.joinColumn = styles.getStyle( self.soClass).tableReference(self.soClass.sqlmeta.table) self.clause = ( sqlbuilder.Field(self.otherClass.sqlmeta.table, self.joinColumn) == self.soClass.q.id) def __get__(self, obj, type): if obj is None: return self query = ( sqlbuilder.Field(self.otherClass.sqlmeta.table, self.joinColumn) == obj.id) select = self.otherClass.select(query) return _OneToManySelectWrapper(obj, self, select)
class OneToMany(boundattributes.BoundFactory):
    """Declarative one-to-many attribute backed by SOOneToMany.

    ``join`` (the other class name) is listed in ``__unpackargs__``,
    presumably so it can be passed positionally -- see BoundFactory.
    """
    factory_class = SOOneToMany
    __restrict_attributes__ = (
        'join', 'joinColumn')
    __unpackargs__ = ('join',)

    # Default values:
    joinColumn = None
class _OneToManySelectWrapper(object): def __init__(self, forObject, join, select): self.forObject = forObject self.join = join self.select = select def __getattr__(self, attr): # @@: This passes through private variable access too... should it? # Also magic methods, like __str__ return getattr(self.select, attr) def __repr__(self): return '<%s for: %s>' % (self.__class__.__name__, repr(self.select)) def __str__(self): return str(self.select) def __iter__(self): return iter(self.select) def __getitem__(self, key): return self.select[key] def create(self, **kw): kw[self.join.joinColumn] = self.forObject.id return self.join.otherClass(**kw)