from itertools import count
from . import boundattributes
from . import classregistry
from . import events
from . import sresults
from . import styles
from . import sqlbuilder
from .styles import capword
__all__ = [
    'MultipleJoin', 'SQLMultipleJoin',
    'RelatedJoin', 'SQLRelatedJoin',
    'SingleJoin', 'ManyToMany', 'OneToMany',
]

# Shared counter: every Join takes the next value at construction time,
# recording the order in which joins were declared.
creationOrder = count()

# Sentinel meaning "no value supplied", re-exported from sqlbuilder.
NoDefault = sqlbuilder.NoDefault
def getID(obj):
    """Return the database id for *obj*.

    *obj* may be an object with an ``id`` attribute (e.g. an SQLObject
    instance) or anything ``int()`` accepts (e.g. a raw id); in the latter
    case the integer conversion is returned.
    """
    try:
        ident = obj.id
    except AttributeError:
        ident = int(obj)
    return ident
class Join(object):
    """Declarative description of a join, attached to a class via withClass().

    Keyword arguments are stored in ``self.kw`` and forwarded to
    ``baseClass`` (supplied by subclasses) when ``withClass`` is called.
    """

    def __init__(self, otherClass=None, **kw):
        kw['otherClass'] = otherClass
        self.kw = kw
        # joinMethodName is kept separately so it can be checked/overridden
        # through the joinMethodName property.
        self._joinMethodName = kw.pop('joinMethodName', None)
        # Declaration order, taken from the module-wide counter.
        self.creationOrder = next(creationOrder)

    def _set_joinMethodName(self, value):
        current = self._joinMethodName
        # Only allow setting it when unset, or re-setting the same value.
        assert current == value or current is None, \
            "You have already given an explicit joinMethodName (%s), " \
            "and you are now setting it to %s" % (current, value)
        self._joinMethodName = value

    def _get_joinMethodName(self):
        return self._joinMethodName

    joinMethodName = property(_get_joinMethodName, _set_joinMethodName)
    # ``name`` is an alias for the same property.
    name = joinMethodName

    def withClass(self, soClass):
        """Instantiate ``baseClass`` for *soClass* with the stored options."""
        # __init__ already pops joinMethodName; this only matters if it was
        # put back into self.kw afterwards.
        if 'joinMethodName' in self.kw:
            self._joinMethodName = self.kw.pop('joinMethodName')
        return self.baseClass(
            creationOrder=self.creationOrder,
            soClass=soClass,
            joinDef=self,
            joinMethodName=self._joinMethodName,
            **self.kw)
# A join is separate from a foreign key, i.e., it is
# many-to-many, or one-to-many where the *other* class
# has the foreign key.
class SOJoin(object):
    """A join bound to a concrete SQLObject class (``soClass``).

    The class on the other side is passed by name and resolved through the
    class registry; ``self.otherClass`` is set when the registry callback
    fires.
    """

    def __init__(self,
                 creationOrder,
                 soClass=None,
                 otherClass=None,
                 joinColumn=None,
                 joinMethodName=None,
                 orderBy=NoDefault,
                 joinDef=None):
        self.creationOrder = creationOrder
        self.soClass = soClass
        self.joinDef = joinDef
        self.otherClassName = otherClass
        classregistry.registry(soClass.sqlmeta.registry).addClassCallback(
            otherClass, self._setOtherClass)
        self.joinColumn = joinColumn
        self.joinMethodName = joinMethodName
        self._orderBy = orderBy
        if not self.joinColumn:
            # Default one-to-many shape: the other class has a column that
            # refers back to our table.
            ourStyle = styles.getStyle(self.soClass)
            self.joinColumn = ourStyle.tableReference(
                self.soClass.sqlmeta.table)

    @property
    def orderBy(self):
        """Ordering for join results.

        When no explicit orderBy was given, this falls back to (and caches)
        the other class's ``sqlmeta.defaultOrder``.
        """
        if self._orderBy is NoDefault:
            # Resolved lazily: otherClass may not exist yet at __init__ time.
            self._orderBy = self.otherClass.sqlmeta.defaultOrder
        return self._orderBy

    def _setOtherClass(self, cls):
        self.otherClass = cls

    def hasIntermediateTable(self):
        return False

    def _applyOrderBy(self, results, defaultSortClass):
        """Sort the list *results* in place by ``self.orderBy`` unless it is
        None, and return the same list. *defaultSortClass* is unused."""
        ordering = self.orderBy
        if ordering is not None:
            doSort(results, ordering)
        return results
class MinType(object):
    """Sentinel that compares less than every other object.

    Used as the sort key for None values so that rows with a None attribute
    can be sorted alongside non-None ones.
    """
    # functools.total_ordering would simplify this

    def __lt__(self, other):
        # Strictly smaller than everything except itself.
        return self is not other

    def __eq__(self, other):
        return self is other

    def __gt__(self, other):
        return False

    def __le__(self, other):
        return True

    def __ge__(self, other):
        # Only as large as itself.
        return self is other


# Singleton instance used as the sort key for None.
Min = MinType()
def doSort(results, orderBy):
    """Sort the list *results* in place according to *orderBy*.

    *orderBy* may be:

    * an attribute name, optionally prefixed with ``'-'`` for descending
      order;
    * an ``sqlbuilder.SQLObjectField``, or ``sqlbuilder.DESC`` wrapping one;
    * a tuple/list of the above, where the FIRST item is the most
      significant key (matching SQL ``ORDER BY a, b``).  An empty
      sequence leaves *results* unchanged.

    ``None`` attribute values sort before all other values (after them
    when descending).
    """
    if isinstance(orderBy, (tuple, list)):
        # list.sort is stable, so sorting by the least significant key
        # first and the most significant (first) key last yields the same
        # order as SQL's ``ORDER BY a, b, ...``.
        for key in reversed(orderBy):
            doSort(results, key)
        return
    if isinstance(orderBy, sqlbuilder.DESC) \
            and isinstance(orderBy.expr, sqlbuilder.SQLObjectField):
        orderBy = '-' + orderBy.expr.original
    elif isinstance(orderBy, sqlbuilder.SQLObjectField):
        orderBy = orderBy.original
    # @@: but we don't handle more complex expressions for orderings
    if orderBy.startswith('-'):
        attr = orderBy[1:]
        reverse = True
    else:
        attr = orderBy
        reverse = False

    def sortkey(row):
        value = getattr(row, attr)
        # None is not orderable against most types; map it to Min.
        return Min if value is None else value

    results.sort(key=sortkey, reverse=reverse)
# This is a one-to-many
class SOMultipleJoin(SOJoin):
    """One-to-many join: otherClass rows whose joinColumn refers to us."""

    def __init__(self, addRemoveName=None, **kw):
        # addRemoveName is stored as self.addRemoveName (presumably used to
        # name add/remove helper methods); it defaults to the capitalized
        # other class name.
        SOJoin.__init__(self, **kw)
        if not self.joinMethodName:
            # Default accessor name: other class name with a lower-case
            # first letter, naively pluralized.
            base = self.otherClassName[0].lower() + self.otherClassName[1:]
            suffix = "es" if base.endswith('s') else "s"
            self.joinMethodName = base + suffix
        self.addRemoveName = addRemoveName or capword(self.otherClassName)

    def performJoin(self, inst):
        """Return the list of otherClass instances joined to *inst*,
        sorted in Python according to orderBy."""
        rows = inst._connection._SO_selectJoin(
            self.otherClass, self.joinColumn, inst.id)
        conn = inst._connection if inst.sqlmeta._perConnection else None
        objects = [self.otherClass.get(rowID, conn)
                   for (rowID,) in rows if rowID is not None]
        return self._applyOrderBy(objects, self.otherClass)

    def _dbNameToPythonName(self):
        """Map self.joinColumn (a DB column name) to the matching Python
        attribute name on otherClass, falling back to the style's
        conversion when no column matches."""
        for col in self.otherClass.sqlmeta.columns.values():
            if col.dbName == self.joinColumn:
                return col.name
        style = self.soClass.sqlmeta.style
        return style.dbColumnToPythonAttr(self.joinColumn)
class MultipleJoin(Join):
    """Declare a one-to-many join (implemented by SOMultipleJoin)."""
    baseClass = SOMultipleJoin
class SOSQLMultipleJoin(SOMultipleJoin):
    """One-to-many join returning a select result ordered in SQL."""

    def performJoin(self, inst):
        conn = inst._connection if inst.sqlmeta._perConnection else None
        attrName = self._dbNameToPythonName()
        condition = getattr(self.otherClass.q, attrName) == inst.id
        selection = self.otherClass.select(condition, connection=conn)
        # Ordering is applied to the query rather than sorted in Python.
        return selection.orderBy(self.orderBy)
class SQLMultipleJoin(Join):
    """Declare a one-to-many join that returns a select result
    (implemented by SOSQLMultipleJoin)."""
    baseClass = SOSQLMultipleJoin
# This is a many-to-many join, with an intermediary table
class SORelatedJoin(SOMultipleJoin):
    """Many-to-many join through an intermediate table."""

    def __init__(self,
                 otherColumn=None,
                 intermediateTable=None,
                 createRelatedTable=True,
                 **kw):
        self.intermediateTable = intermediateTable
        self.otherColumn = otherColumn
        self.createRelatedTable = createRelatedTable
        SOMultipleJoin.__init__(self, **kw)
        # Defaults that depend on the other class are filled in when the
        # registry resolves it.
        registry = classregistry.registry(self.soClass.sqlmeta.registry)
        registry.addClassCallback(
            self.otherClassName, self._setOtherRelatedClass)

    def _setOtherRelatedClass(self, otherClass):
        """Fill in default intermediateTable/otherColumn from *otherClass*."""
        if not self.intermediateTable:
            # Default name: both table names, sorted, joined with '_'.
            first, second = sorted([self.soClass.sqlmeta.table,
                                    otherClass.sqlmeta.table])
            self.intermediateTable = '%s_%s' % (first, second)
        if not self.otherColumn:
            ourStyle = self.soClass.sqlmeta.style
            self.otherColumn = ourStyle.tableReference(
                otherClass.sqlmeta.table)

    def hasIntermediateTable(self):
        return True

    def performJoin(self, inst):
        """Return otherClass instances linked to *inst* through the
        intermediate table, sorted in Python according to orderBy."""
        rows = inst._connection._SO_intermediateJoin(
            self.intermediateTable, self.otherColumn,
            self.joinColumn, inst.id)
        conn = inst._connection if inst.sqlmeta._perConnection else None
        related = [self.otherClass.get(rowID, conn)
                   for (rowID,) in rows if rowID is not None]
        return self._applyOrderBy(related, self.otherClass)

    def remove(self, inst, other):
        """Unlink *other* from *inst* via the connection's
        intermediate-table delete."""
        inst._connection._SO_intermediateDelete(
            self.intermediateTable, self.joinColumn, getID(inst),
            self.otherColumn, getID(other))

    def add(self, inst, other):
        """Link *other* to *inst* via the connection's intermediate-table
        insert."""
        inst._connection._SO_intermediateInsert(
            self.intermediateTable, self.joinColumn, getID(inst),
            self.otherColumn, getID(other))
# helper classes to SQLRelatedJoin
class OtherTableToJoin(sqlbuilder.SQLExpression):
    """SQL condition ``<other>.<otherIdName> = <interTable>.<joinColumn>``.

    ``<other>`` is *alias* when one is given, otherwise *otherTable*.
    """

    def __init__(self, otherTable, otherIdName, interTable, joinColumn, alias):
        self.otherTable = otherTable
        self.otherIdName = otherIdName
        self.interTable = interTable
        self.joinColumn = joinColumn
        self.alias = alias

    def tablesUsedImmediate(self):
        if self.alias:
            otherRef = '%s %s' % (self.otherTable, self.alias)
        else:
            otherRef = self.otherTable
        return [otherRef, self.interTable]

    def __sqlrepr__(self, db):
        qualifier = self.alias or self.otherTable
        return '%s.%s = %s.%s' % (qualifier, self.otherIdName,
                                  self.interTable, self.joinColumn)
class JoinToTable(sqlbuilder.SQLExpression):
    """SQL condition ``<interTable>.<joinColumn> = <table>.<idName>``.

    ``<table>`` is rendered as *alias* when one is given.
    """

    def __init__(self, table, idName, interTable, joinColumn, alias):
        self.table = table
        self.idName = idName
        self.interTable = interTable
        self.joinColumn = joinColumn
        self.alias = alias

    def tablesUsedImmediate(self):
        if self.alias:
            tableRef = '%s %s' % (self.table, self.alias)
        else:
            tableRef = self.table
        return [tableRef, self.interTable]

    def __sqlrepr__(self, db):
        qualifier = self.alias or self.table
        return '%s.%s = %s.%s' % (self.interTable, self.joinColumn,
                                  qualifier, self.idName)
class TableToId(sqlbuilder.SQLExpression):
    """SQL condition ``<table>.<idName> = <idValue>``.

    ``<table>`` is rendered as *alias* when one is given.
    """

    def __init__(self, table, idName, idValue, alias):
        self.table = table
        self.idName = idName
        self.idValue = idValue
        self.alias = alias

    def tablesUsedImmediate(self):
        if self.alias:
            return ['%s %s' % (self.table, self.alias)]
        return [self.table]

    def __sqlrepr__(self, db):
        # Render the id through sqlrepr so it is quoted/escaped for the
        # backend; interpolating it raw yields invalid SQL for non-integer
        # (e.g. string) ids.
        return '%s.%s = %s' % (
            self.alias if self.alias else self.table,
            self.idName, sqlbuilder.sqlrepr(self.idValue, db))
class SQLJoinSelectResults(sresults.SelectResults):
    """Select results for a self-referential SQLRelatedJoin.

    In that query the real table appears under two aliases, so filters
    that reference the real table name directly are rejected.
    """

    def filter(self, filter_clause):
        usedTables = filter_clause.tablesUsed(None)
        realClass = self._SOSQLRelatedJoin_realSourceClass
        if realClass.sqlmeta.table not in usedTables:
            return sresults.SelectResults.filter(self, filter_clause)
        className = realClass.__name__
        raise ValueError(
            "Using table '%s' in the filter expression without an alias "
            "could produce wrong SQL. Most probably you need "
            "Alias(%s, '_SO_SQLRelatedJoin_OtherTable') instead."
            % (className, className)
        )
class SOSQLRelatedJoin(SORelatedJoin):
    """Many-to-many join that returns a select result built in SQL."""

    def performJoin(self, inst):
        if inst.sqlmeta._perConnection:
            conn = inst._connection
        else:
            conn = None
        # A self-referential join names the same table twice, so each side
        # gets its own alias.
        needAlias = self.soClass is self.otherClass
        if needAlias:
            thisAlias = '_SO_SQLRelatedJoin_ThisTable'
            otherAlias = '_SO_SQLRelatedJoin_OtherTable'
            source = sqlbuilder.Alias(self.otherClass, otherAlias)
            sresultsClass = SQLJoinSelectResults
        else:
            thisAlias = otherAlias = ''
            source = self.otherClass
            sresultsClass = self.otherClass.SelectResultsClass
        thisTable = self.soClass.sqlmeta.table
        otherTable = self.otherClass.sqlmeta.table
        results = sresultsClass(
            source,
            sqlbuilder.AND(
                OtherTableToJoin(
                    otherTable, self.otherClass.sqlmeta.idName,
                    self.intermediateTable, self.otherColumn, otherAlias,
                ),
                JoinToTable(
                    thisTable, self.soClass.sqlmeta.idName,
                    self.intermediateTable, self.joinColumn, thisAlias,
                ),
                TableToId(
                    thisTable, self.soClass.sqlmeta.idName,
                    inst.id, thisAlias,
                ),
            ),
            clauseTables=(
                '%s %s' % (thisTable, thisAlias) if needAlias else thisTable,
                '%s %s' % (otherTable, otherAlias)
                if needAlias else otherTable,
                self.intermediateTable,
            ),
            connection=conn,
            orderBy=self.orderBy,
        )
        # Recorded so SQLJoinSelectResults.filter can detect unaliased use
        # of the real table.
        results._SOSQLRelatedJoin_realSourceClass = self.otherClass
        return results


class RelatedJoin(MultipleJoin):
    """Declare a many-to-many join through an intermediate table
    (implemented by SORelatedJoin)."""
    baseClass = SORelatedJoin


class SQLRelatedJoin(RelatedJoin):
    """Like RelatedJoin, but the join returns a select result
    (implemented by SOSQLRelatedJoin)."""
    baseClass = SOSQLRelatedJoin
class SOSingleJoin(SOMultipleJoin):
    """One-to-one join: returns a single otherClass row, or None."""

    def __init__(self, **kw):
        # makeDefault: when no matching row exists, create one instead of
        # returning None.
        self.makeDefault = kw.pop('makeDefault', False)
        SOMultipleJoin.__init__(self, **kw)

    def performJoin(self, inst):
        conn = inst._connection if inst.sqlmeta._perConnection else None
        attrName = self._dbNameToPythonName()
        matches = self.otherClass.select(
            getattr(self.otherClass.q, attrName) == inst.id,
            connection=conn
        )
        if matches.count() != 0:
            return matches[0]
        if not self.makeDefault:
            return None
        # No row yet: create one, passing inst as the value of the attribute
        # derived from the join column.
        ourStyle = self.soClass.sqlmeta.style
        return self.otherClass(
            **{ourStyle.instanceIDAttrToAttr(attrName): inst})
class SingleJoin(Join):
    """Declare a one-to-one join (implemented by SOSingleJoin)."""
    baseClass = SOSingleJoin
class SOManyToMany(object):
    """Descriptor behind ManyToMany: a many-to-many relation through an
    intermediate table.

    Both classes are resolved through the class registry; once both are
    known, _finishSet fills in default table/column names and hooks table
    creation.
    """

    def __init__(self, soClass, name, join,
                 intermediateTable, joinColumn, otherColumn,
                 createJoinTable, **attrs):
        self.name = name
        self.intermediateTable = intermediateTable
        self.joinColumn = joinColumn
        self.otherColumn = otherColumn
        self.createJoinTable = createJoinTable
        self.soClass = None
        self.otherClass = None
        # Extra declared attributes are copied onto the descriptor.
        for attrName, attrValue in attrs.items():
            setattr(self, attrName, attrValue)
        classregistry.registry(soClass.sqlmeta.registry).addClassCallback(
            join, self._setOtherClass)
        classregistry.registry(soClass.sqlmeta.registry).addClassCallback(
            soClass.__name__, self._setThisClass)

    def _setThisClass(self, soClass):
        self.soClass = soClass
        bothKnown = self.soClass and self.otherClass
        if bothKnown:
            self._finishSet()

    def _setOtherClass(self, otherClass):
        self.otherClass = otherClass
        bothKnown = self.soClass and self.otherClass
        if bothKnown:
            self._finishSet()

    def _finishSet(self):
        """Fill in defaults and register table-creation listeners."""
        thisMeta = self.soClass.sqlmeta
        otherMeta = self.otherClass.sqlmeta
        if self.intermediateTable is None:
            first, second = sorted([thisMeta.table, otherMeta.table])
            self.intermediateTable = '%s_%s' % (first, second)
        if not self.otherColumn:
            self.otherColumn = thisMeta.style.tableReference(otherMeta.table)
        if not self.joinColumn:
            self.joinColumn = styles.getStyle(
                self.soClass).tableReference(thisMeta.table)
        for cls in (self.soClass, self.otherClass):
            events.listen(self.event_CreateTableSignal,
                          cls, events.CreateTableSignal)
        interOther = sqlbuilder.Field(self.intermediateTable, self.otherColumn)
        interThis = sqlbuilder.Field(self.intermediateTable, self.joinColumn)
        self.clause = ((self.otherClass.q.id == interOther)
                       & (interThis == self.soClass.q.id))

    def __get__(self, obj, type):
        if obj is None:
            return self
        interOther = sqlbuilder.Field(self.intermediateTable, self.otherColumn)
        interThis = sqlbuilder.Field(self.intermediateTable, self.joinColumn)
        query = (self.otherClass.q.id == interOther) & (interThis == obj.id)
        return _ManyToManySelectWrapper(
            obj, self, self.otherClass.select(query))

    def event_CreateTableSignal(self, soClass, connection, extra_sql,
                                post_funcs):
        """Schedule intermediate-table creation after *soClass*'s table."""
        if not self.createJoinTable:
            return
        post_funcs.append(self.event_CreateTableSignalPost)

    def event_CreateTableSignalPost(self, soClass, connection):
        """Create the intermediate table unless it already exists."""
        if not connection.tableExists(self.intermediateTable):
            connection._SO_createJoinTable(self)
# Declarative factory for a many-to-many relation; produces SOManyToMany.
class ManyToMany(boundattributes.BoundFactory):
    factory_class = SOManyToMany
    __restrict_attributes__ = (
        'join', 'intermediateTable',
        'joinColumn', 'otherColumn', 'createJoinTable')
    __unpackargs__ = ('join',)
    # Default values:
    intermediateTable = None
    joinColumn = None
    otherColumn = None
    createJoinTable = True
class _ManyToManySelectWrapper(object):
    """Value of a ManyToMany attribute on an instance.

    Proxies the underlying select (iteration, indexing, str/repr and other
    attribute access) and adds add/remove/create helpers that maintain the
    intermediate table for ``forObject``.
    """

    def __init__(self, forObject, join, select):
        self.forObject = forObject
        self.join = join
        self.select = select

    def __getattr__(self, attr):
        # @@: This passes through private variable access too... should it?
        # Implicit special-method lookups (str(), iter(), [] ...) bypass
        # __getattr__, which is why those dunders are defined explicitly.
        return getattr(self.select, attr)

    def __repr__(self):
        return '<%s for: %s>' % (self.__class__.__name__, repr(self.select))

    def __str__(self):
        return str(self.select)

    def __iter__(self):
        return iter(self.select)

    def __getitem__(self, key):
        return self.select[key]

    def add(self, obj):
        """Link *obj* (an instance or a raw id) to forObject."""
        # Use forObject's connection: *obj* may be a bare id (getID accepts
        # ints), which has no _connection attribute.
        self.forObject._connection._SO_intermediateInsert(
            self.join.intermediateTable,
            self.join.joinColumn,
            getID(self.forObject),
            self.join.otherColumn,
            getID(obj))

    def remove(self, obj):
        """Unlink *obj* (an instance or a raw id) from forObject."""
        self.forObject._connection._SO_intermediateDelete(
            self.join.intermediateTable,
            self.join.joinColumn,
            getID(self.forObject),
            self.join.otherColumn,
            getID(obj))

    def create(self, **kw):
        """Create a new otherClass instance from *kw*, link it, return it."""
        obj = self.join.otherClass(**kw)
        self.add(obj)
        return obj
class SOOneToMany(object):
    """Descriptor behind OneToMany: otherClass rows whose joinColumn
    equals this instance's id."""

    def __init__(self, soClass, name, join, joinColumn, **attrs):
        self.soClass = soClass
        self.name = name
        self.joinColumn = joinColumn
        # Extra declared attributes are copied onto the descriptor.
        for attrName, attrValue in attrs.items():
            setattr(self, attrName, attrValue)
        registry = classregistry.registry(soClass.sqlmeta.registry)
        registry.addClassCallback(join, self._setOtherClass)

    def _setOtherClass(self, otherClass):
        self.otherClass = otherClass
        if not self.joinColumn:
            ourStyle = styles.getStyle(self.soClass)
            self.joinColumn = ourStyle.tableReference(
                self.soClass.sqlmeta.table)
        joinField = sqlbuilder.Field(otherClass.sqlmeta.table,
                                     self.joinColumn)
        self.clause = joinField == self.soClass.q.id

    def __get__(self, obj, type):
        if obj is None:
            return self
        joinField = sqlbuilder.Field(self.otherClass.sqlmeta.table,
                                     self.joinColumn)
        selection = self.otherClass.select(joinField == obj.id)
        return _OneToManySelectWrapper(obj, self, selection)
# Declarative factory for a one-to-many relation; produces SOOneToMany.
class OneToMany(boundattributes.BoundFactory):
    factory_class = SOOneToMany
    __restrict_attributes__ = (
        'join', 'joinColumn')
    __unpackargs__ = ('join',)
    # Default values:
    joinColumn = None
class _OneToManySelectWrapper(object):
    """Value of a OneToMany attribute on an instance.

    Proxies the underlying select (iteration, indexing, str/repr and other
    attribute access) and adds a create helper that links new rows to
    ``forObject``.
    """

    def __init__(self, forObject, join, select):
        self.forObject = forObject
        self.join = join
        self.select = select

    def __getattr__(self, attr):
        # @@: This passes through private variable access too... should it?
        # Implicit special-method lookups (str(), iter(), [] ...) bypass
        # __getattr__, which is why those dunders are defined explicitly.
        return getattr(self.select, attr)

    def __repr__(self):
        return '<%s for: %s>' % (self.__class__.__name__, repr(self.select))

    def __str__(self):
        return str(self.select)

    def __iter__(self):
        return iter(self.select)

    def __getitem__(self, key):
        return self.select[key]

    def create(self, **kw):
        """Create an otherClass instance whose join column is forObject's id.

        ``join.joinColumn`` is a database column name (it is used to build
        an ``sqlbuilder.Field``), but the constructor takes Python attribute
        names, so it is translated via otherClass's column definitions.
        The raw name is used when no column matches.
        """
        otherClass = self.join.otherClass
        attrName = self.join.joinColumn
        for column in otherClass.sqlmeta.columns.values():
            if column.dbName == self.join.joinColumn:
                attrName = column.name
                break
        kw[attrName] = self.forObject.id
        return otherClass(**kw)