0001import threading
0002from util.threadinglocal import local as threading_local
0003import sys
0004import re
0005import warnings
0006import atexit
0007import os
0008import new
0009import types
0010import urllib
0011import weakref
0012import inspect
0013import sqlbuilder
0014from cache import CacheSet
0015import col
0016import main
0017from joins import sorter
0018from converters import sqlrepr
0019import classregistry
0020
# Some DB-API drivers warn every time cursor.lastrowid is read; the
# connection classes rely on it, so the warning is just noise here.
warnings.filterwarnings("ignore", message="DB-API extension cursor.lastrowid used")

# NOTE(review): not referenced anywhere else in this module; kept in case
# external code imports it.
_connections = dict()
0024
def _closeConnection(ref):
    """
    atexit hook: close the connection behind the weak reference *ref*.

    Nothing happens if the connection has already been garbage-collected.
    """
    target = ref()
    if target is None:
        return
    target.close()
0029
class ConsoleWriter:
    """
    Debug writer that prints each message as one line on a ``sys`` stream.

    ``loglevel`` names the stream attribute of ``sys`` (``"stdout"`` when
    empty).  Unicode messages are encoded with the connection's
    ``dbEncoding`` (``"ascii"`` when the connection has none).
    """

    def __init__(self, connection, loglevel):
        encoding = getattr(connection, "dbEncoding", None)
        self.loglevel = loglevel or "stdout"
        self.dbEncoding = encoding or "ascii"

    def write(self, text):
        # Look the stream up on every write so a replaced sys.stdout/stderr
        # is honoured.
        stream = getattr(sys, self.loglevel)
        if isinstance(text, unicode):
            try:
                text = text.encode(self.dbEncoding)
            except UnicodeEncodeError:
                # Fall back to the escaped repr without its u'...' wrapper.
                text = repr(text)[2:-1]
        stream.write(text + '\n')
0043
class LogWriter:
    """
    Debug writer that forwards each message to a ``logging`` logger.

    ``loglevel`` names the logger method to call (e.g. ``"info"``).  When it
    is empty it defaults to ``"debug"``, mirroring ConsoleWriter's default.
    DBConnection passes ``loglevel=None`` by default, so without this
    default, configuring only a logger name failed with
    ``getattr(logger, None)`` -> TypeError.
    """

    def __init__(self, connection, logger, loglevel):
        self.logger = logger
        self.loglevel = loglevel or "debug"
        self.logmethod = getattr(logger, self.loglevel)

    def write(self, text):
        self.logmethod(text)
0051
def makeDebugWriter(connection, loggerName, loglevel):
    """
    Build the object that receives debug output for *connection*.

    With a *loggerName* the messages go to that ``logging`` logger through
    LogWriter; otherwise ConsoleWriter prints them to a ``sys`` stream.
    """
    if loggerName:
        # logging is only imported when a logger is actually configured.
        import logging
        return LogWriter(connection, logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(connection, loglevel)
0058
class Boolean(object):
    """
    Convert a value to ``bool``, understanding string keywords.

    The strings yes/no, true/false, on/off and 1/0 (case-insensitive) map to
    the obvious booleans.  Anything else, including non-strings, falls back
    to ``bool(value)``.  Calling Boolean returns a plain ``bool``, never a
    Boolean instance.
    """
    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
                 '0': False, 'no': False, 'false': False, 'off': False}

    def __new__(cls, value):
        try:
            key = value.lower()
            return Boolean._keywords[key]
        except (KeyError, AttributeError):
            # Not a string, or not one of the keywords.
            pass
        return bool(value)
0068
class DBConnection:

    """
    Base class for SQLObject database connections.

    Stores the options shared by every backend (debugging, caching, style,
    auto-commit, class registry) and provides URI helpers for the driver
    modules.  ``uri()`` reads ``dbName``, ``host`` and ``db`` (and optionally
    ``user`` / ``password``), which subclasses are expected to set.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        self.name = name
        # debug='Pool' is a special level that DBAPI.printDebug compares
        # against verbatim.  Boolean() would turn it into plain True and
        # make pool tracing unreachable, so keep it as given.
        if debug == 'Pool':
            self.debug = debug
        else:
            self.debug = Boolean(debug)
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        self.debugWriter = makeDebugWriter(self, logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        self.style = style
        # id(raw connection) -> sequence number shown in debug output.
        self._connectionNumbers = {}
        self._connectionCount = 1
        # autoCommit='exception' is likewise compared verbatim by
        # DBAPI.releaseConnection, so it must not collapse to True.
        if autoCommit == 'exception':
            self.autoCommit = autoCommit
        else:
            self.autoCommit = Boolean(autoCommit)
        self.registry = registry or None
        classregistry.registry(self.registry).addCallback(
            self.soClassAdded)
        registerConnectionInstance(self)
        # Use a weak reference so the exit hook does not keep the
        # connection alive.
        atexit.register(_closeConnection, weakref.ref(self))

    def uri(self):
        """
        Rebuild the connection URI from this connection's attributes.

        The result has the form ``dbName://[user[:password]@][host]/db``.
        A leading '/' on ``db`` is dropped because the separator is already
        present.  A password without a user cannot be expressed; that case
        triggers an assertion.
        """
        auth = getattr(self, 'user', '')
        password = getattr(self, 'password', None)
        if auth:
            if password:
                auth = auth + ':' + password
            auth = auth + '@'
        else:
            assert not password, (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
        uri += '/'
        db = self.db
        if db.startswith('/'):
            # Strip the leading slash of db itself (the '/' separator was
            # appended above).
            db = db[1:]
        return uri + db

    def connectionFromURI(cls, uri):
        """Create a connection from *uri*; driver subclasses override this."""
        # NotImplemented is a comparison sentinel, not an exception, and
        # raising it only produced a confusing TypeError.
        raise NotImplementedError(
            "%s does not implement connectionFromURI" % cls.__name__)
    connectionFromURI = classmethod(connectionFromURI)

    def _parseURI(uri):
        """
        Split a connection URI into its components.

        Returns ``(user, password, host, port, path, args)``.  Components
        absent from the URI are None (``args`` is a dict of the unquoted
        ``?name=value&...`` parameters).  ``port`` is converted to an int in
        1-65535.  On Windows, a ``C|/...`` path becomes ``C:/...``.

        Raises ValueError for a malformed port.  A URI whose part after
        the scheme does not start with '/' fails an assertion.
        """
        schema, rest = uri.split(':', 1)
        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
        if rest.startswith('/') and not rest.startswith('//'):
            # scheme:/path -- no host part.
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            # scheme:///path -- empty host part.
            host = None
            rest = rest[3:]
        else:
            # scheme://host[/path]
            rest = rest[2:]
            if rest.find('/') == -1:
                host = rest
                rest = ''
            else:
                host, rest = rest.split('/', 1)
        if host and host.find('@') != -1:
            # rsplit: the password itself may contain '@'.
            user, host = host.rsplit('@', 1)
            if user.find(':') != -1:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and host.find(':') != -1:
            _host, port = host.split(':')
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
            host = _host
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if path.find('?') != -1:
            path, arglist = path.split('?', 1)
            for single in arglist.split('&'):
                if not single:
                    # Tolerate empty segments such as a trailing '&'.
                    continue
                argname, argvalue = single.split('=', 1)
                args[argname] = urllib.unquote(argvalue)
        return user, password, host, port, path, args
    _parseURI = staticmethod(_parseURI)

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()
0185
class ConnWrapper(object):

    """
    A SQLObject class bound to one specific connection.

    Instances carry their own connection, but classes are global.  This
    wrapper therefore injects the connection lazily: calling it
    instantiates the class with ``connection=...``, and class methods
    whose signature accepts ``connection`` are returned pre-bound to it.
    """

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        found = getattr(self._soClass, attr)
        if not isinstance(found, types.MethodType):
            # Plain attributes and non-method values pass straight through.
            return found
        try:
            wantsConnection = found.takes_connection
        except AttributeError:
            argNames, varargs, varkw, _defaults = inspect.getargspec(found)
            assert not varkw and not varargs, (
                "I cannot tell whether I must wrap this method, "
                "because it takes **kw: %r"
                % found)
            wantsConnection = 'connection' in argNames
            # Cache the answer on the function so later lookups skip
            # signature introspection.
            found.im_func.takes_connection = wantsConnection
        if wantsConnection:
            return ConnMethodWrapper(found, self._connection)
        return found
0224
class ConnMethodWrapper(object):

    """
    Callable proxy for a method that always receives a fixed connection.

    Calling it forces the ``connection`` keyword argument to the bound
    connection.  Any other attribute lookup is delegated to the wrapped
    method.
    """

    def __init__(self, method, connection):
        self._method = method
        self._connection = connection

    def __getattr__(self, attr):
        return getattr(self._method, attr)

    def __call__(self, *args, **kw):
        kw.update(connection=self._connection)
        return self._method(*args, **kw)

    def __repr__(self):
        details = (self._method, self._connection)
        return '<Wrapped %r with connection %r>' % details
0241
class DBAPI(DBConnection):

    """
    Base class for connections built on a DB-API 2.0 driver module.

    Subclass must define a `makeConnection()` method, which
    returns a newly-created connection object.

    ``_queryInsertID(conn, soInstance, id, names, values)`` must also be
    defined; :meth:`queryInsertID` calls it.  The code below also relies on
    the subclass providing ``module`` (the DB-API module; ``Binary`` and
    ``Error`` are used), ``supportTransactions``, ``_setAutoCommit``,
    ``createIDColumn``, ``joinSQLType``, ``createColumn``,
    ``createReferenceConstraint`` and ``createIndexSQL``.
    """

    dbName = None

    def __init__(self, **kw):
        # Idle raw connections available for reuse, guarded by _poolLock.
        # When _pool is None, released connections are closed instead.
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        self._binaryType = type(self.module.Binary(''))

    def _runWithConnection(self, meth, *args):
        """Call ``meth(conn, *args)`` on a pooled connection, always releasing it."""
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val

    def getConnection(self):
        """Take an idle connection from the pool, or make a new one if none is idle."""
        self._poolLock.acquire()
        try:
            if not self._pool:
                conn = self.makeConnection()
                # Number connections in creation order for debug output.
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            else:
                conn = self._pool.pop()
            if self.debug:
                s = 'ACQUIRE'
                if self._pool is not None:
                    s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
                self.printDebug(conn, s, 'Pool')
            return conn
        finally:
            self._poolLock.release()

    def releaseConnection(self, conn, explicit=False):
        """
        Hand ``conn`` back after use.

        For an implicit release (``explicit`` false) on a backend with
        transaction support, pending work is committed when ``autoCommit``
        is true and rolled back when it is false.  With
        ``autoCommit == 'exception'`` the work is rolled back and an
        Exception is raised.  The connection is then returned to the pool,
        or closed when pooling is disabled.
        """
        if self.debug:
            if explicit:
                s = 'RELEASE (explicit)'
            else:
                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                s += ' no pooling'
            else:
                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        if self.supportTransactions and not explicit:
            if self.autoCommit == 'exception':
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                # The connection is clean after the rollback; give it back
                # before raising so it is not leaked.
                self._returnConnection(conn)
                raise Exception('Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed')
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                # Skip the explicit COMMIT when the raw connection reports
                # that it is in autocommit mode.
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        self._returnConnection(conn)

    def _returnConnection(self, conn):
        """Put ``conn`` back into the pool, or close it when pooling is disabled."""
        if self._pool is not None:
            if conn not in self._pool:
                self._pool.insert(0, conn)
        else:
            conn.close()

    def printDebug(self, conn, s, name, type='query'):
        """
        Write one debug line about ``conn`` through ``self.debugWriter``.

        ``name`` labels the operation.  'Pool' messages are only written
        when ``debug == 'Pool'``.  For ``type`` other than 'query' (e.g.
        'result'), ``s`` is shown via ``repr`` after an arrow separator.
        """
        if name == 'Pool' and self.debug != 'Pool':
            return
        if type == 'query':
            sep = ': '
        else:
            sep = '->'
            s = repr(s)
        n = self._connectionNumbers[id(conn)]
        spaces = ' '*(8-len(name))
        if self.debugThreading:
            threadName = threading.currentThread().getName()
            threadName = (':' + threadName + ' '*(8-len(threadName)))
        else:
            threadName = ''
        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
        self.debugWriter.write(msg)

    def _executeRetry(self, conn, cursor, query):
        """Execute ``query`` on ``cursor``; the hook where subclasses can add retries."""
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)

    def _query(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'Query')
        self._executeRetry(conn, conn.cursor(), s)

    def query(self, s):
        """Execute ``s``, discarding any result."""
        return self._runWithConnection(self._query, s)

    def _queryAll(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return value

    def queryAll(self, s):
        """Execute ``s`` and return all rows (``cursor.fetchall()``)."""
        return self._runWithConnection(self._queryAll, s)

    def _queryAllDescription(self, conn, s):
        """
        Like queryAll, but returns (description, rows), where the
        description is cursor.description (which gives row types)
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            # Use the same label as the query line above.
            self.printDebug(conn, value, 'QueryAllDesc', 'result')
        return c.description, value

    def queryAllDescription(self, s):
        return self._runWithConnection(self._queryAllDescription, s)

    def _queryOne(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchone()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryOne', 'result')
        return value

    def queryOne(self, s):
        """Execute ``s`` and return the first row (``cursor.fetchone()``)."""
        return self._runWithConnection(self._queryOne, s)

    def _insertSQL(self, table, names, values):
        return ("INSERT INTO %s (%s) VALUES (%s)" %
                (table, ', '.join(names),
                 ', '.join([self.sqlrepr(v) for v in values])))

    def transaction(self):
        """Start a new Transaction on a connection taken from this pool."""
        return Transaction(self)

    def queryInsertID(self, soInstance, id, names, values):
        return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)

    def iterSelect(self, select):
        # The Iteration releases the connection once it is exhausted.
        return select.IterationClass(self, self.getConnection(),
                                     select, keepConnection=False)

    def accumulateSelect(self, select, *expressions):
        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
            to the select object.
        """
        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
        q = self.sqlrepr(q)
        val = self.queryOne(q)
        if len(expressions) == 1:
            val = val[0]
        return val

    def queryForSelect(self, select):
        return self.sqlrepr(select.queryForSelect())

    def _SO_createJoinTable(self, join):
        self.query(self._SO_createJoinTableSQL(join))

    def _SO_createJoinTableSQL(self, join):
        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
                (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join)))

    def _SO_dropJoinTable(self, join):
        self.query("DROP TABLE %s" % join.intermediateTable)

    def _SO_createIndex(self, soClass, index):
        self.query(self.createIndexSQL(soClass, index))

    def createIndexSQL(self, soClass, index):
        assert 0, 'Implement in subclasses'

    def createTable(self, soClass):
        """Create ``soClass``'s table and return the SQL still to run afterwards (constraints, extra SQL)."""
        createSql, constraints = self.createTableSQL(soClass)
        self.query(createSql)
        return constraints

    def createReferenceConstraints(self, soClass):
        """Return the non-empty foreign-key constraint definitions for ``soClass``."""
        refConstraints = [self.createReferenceConstraint(soClass, column)
                          for column in soClass.sqlmeta.columnList
                          if isinstance(column, col.SOForeignKey)]
        refConstraintDefs = [constraint for constraint in refConstraints if constraint]
        return refConstraintDefs

    def createSQL(self, soClass):
        """
        Return ``soClass.sqlmeta.createSQL`` normalised to a list of statements.

        It may be a str, list, tuple, or a dict keyed by backend ``dbName``.
        """
        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
        if tableCreateSQLs:
            assert isinstance(tableCreateSQLs,(str,list,dict,tuple)), (
                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
                (soClass.__name__))
            if isinstance(tableCreateSQLs, dict):
                tableCreateSQLs = tableCreateSQLs.get(soClass._connection.dbName, [])
            if isinstance(tableCreateSQLs, str):
                tableCreateSQLs = [tableCreateSQLs]
            if isinstance(tableCreateSQLs, tuple):
                tableCreateSQLs = list(tableCreateSQLs)
            assert isinstance(tableCreateSQLs,list), (
                'Unable to create a list from %s.sqlmeta.createSQL' %
                (soClass.__name__))
        return tableCreateSQLs or []

    def createTableSQL(self, soClass):
        """Return ``(CREATE TABLE statement, list of follow-up statements)``."""
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        createSql = ('CREATE TABLE %s (\n%s\n)' %
                     (soClass.sqlmeta.table, self.createColumns(soClass)))
        return createSql, constraints + extraSQL

    def createColumns(self, soClass):
        # ``column`` rather than ``col`` so the loop variable does not
        # shadow the imported ``col`` module.
        columnDefs = [self.createIDColumn(soClass)] + [self.createColumn(soClass, column)
                                                       for column in soClass.sqlmeta.columnList]
        return ",\n".join(["    %s" % c for c in columnDefs])

    def createReferenceConstraint(self, soClass, col):
        assert 0, "Implement in subclasses"

    def createColumn(self, soClass, col):
        assert 0, "Implement in subclasses"

    def dropTable(self, tableName, cascade=False):
        self.query("DROP TABLE %s" % tableName)

    def clearTable(self, tableName):
        self.query("DELETE FROM %s" % tableName)

    def createBinary(self, value):
        """
        Create a binary object wrapper for the given database.
        """
        return self.module.Binary(value)

    def _SO_update(self, so, values):
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                               for dbName, value in values]),
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectOne(self, so, columnNames):
        return self._SO_selectOneAlt(so, columnNames, so.q.id==so.id)

    def _SO_selectOneAlt(self, so, columnNames, condition):
        """Fetch one row of ``so``'s table matching ``condition``; string column names become SQL constants."""
        if columnNames:
            columns = [isinstance(x, basestring) and sqlbuilder.SQLConstant(x) or x for x in columnNames]
        else:
            columns = None
        return self.queryOne(self.sqlrepr(sqlbuilder.Select(columns,
                                                            staticTables=[so.sqlmeta.table],
                                                            clause=condition)))

    def _SO_delete(self, so):
        self.query("DELETE FROM %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectJoin(self, soClass, column, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (soClass.sqlmeta.idName,
                              soClass.sqlmeta.table,
                              column,
                              self.sqlrepr(value)))

    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (getColumn,
                              table,
                              joinColumn,
                              self.sqlrepr(value)))

    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
                   (table,
                    firstColumn,
                    self.sqlrepr(firstValue),
                    secondColumn,
                    self.sqlrepr(secondValue)))

    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
                   (table,
                    firstColumn,
                    secondColumn,
                    self.sqlrepr(firstValue),
                    self.sqlrepr(secondValue)))

    def _SO_columnClause(self, soClass, kw):
        """
        Build a WHERE clause ANDing ``column = value`` for the keywords in ``kw``.

        Keys may be 'id', column names, or foreign-key names (SQLObject
        instances are replaced by their id).  Matched keys are popped from
        ``kw``; anything left over raises TypeError.  Returns None when no
        conditions were given.  None values are compared with IS.
        """
        data = {}
        if 'id' in kw:
            data[soClass.sqlmeta.idName] = kw.pop('id')
        for key, column in soClass.sqlmeta.columns.items():
            if key in kw:
                value = kw.pop(key)
                if column.from_python:
                    value = column.from_python(value, sqlbuilder.SQLObjectState(soClass))
                data[column.dbName] = value
            elif column.foreignName in kw:
                obj = kw.pop(column.foreignName)
                if isinstance(obj, main.SQLObject):
                    data[column.dbName] = obj.id
                else:
                    data[column.dbName] = obj
        if kw:
            # Whatever is left does not name a column of soClass.
            raise TypeError("got an unexpected keyword argument(s): %r" % kw.keys())

        if not data:
            return None
        clauses = []
        for dbName, value in data.items():
            # NULL must be matched with IS.  Testing identity (rather than
            # a dict lookup) also copes with unhashable values.
            if value is None:
                op = "IS"
            else:
                op = "="
            clauses.append('%s %s %s' % (dbName, op, self.sqlrepr(value)))
        return ' AND '.join(clauses)

    def sqlrepr(self, v):
        return sqlrepr(v, self.dbName)

    def __del__(self):
        self.close()

    def close(self):
        """Close every idle pooled connection (best effort: driver errors are ignored)."""
        if not hasattr(self, '_pool'):
            # __init__ may not have run (close() is also called from __del__).
            return
        if not self._pool:
            return
        self._poolLock.acquire()
        try:
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    pass
        finally:
            self._poolLock.release()

    def createEmptyDatabase(self):
        """
        Create an empty database.
        """
        raise NotImplementedError
0637
class Iteration(object):

    """
    Iterator over the rows of a SELECT, yielding SQLObject instances.

    The query is executed on ``rawconn`` as soon as the iterator is built.
    Each row's first column is the id handed to ``select.sourceClass.get``.
    A None id yields None.  Once the cursor is exhausted, or the iterator
    is collected, the raw connection is released back to ``dbconn``,
    unless ``keepConnection`` is true.
    """

    def __init__(self, dbconn, rawconn, select, keepConnection=False):
        self.dbconn = dbconn
        self.rawconn = rawconn
        self.select = select
        self.keepConnection = keepConnection
        self.cursor = rawconn.cursor()
        # Setting self.query is what marks the iterator as needing cleanup.
        self.query = self.dbconn.queryForSelect(select)
        if dbconn.debug:
            dbconn.printDebug(rawconn, self.query, 'Select')
        self.dbconn._executeRetry(self.rawconn, self.cursor, self.query)

    def __iter__(self):
        return self

    def next(self):
        row = self.cursor.fetchone()
        if row is None:
            self._cleanup()
            raise StopIteration
        ident = row[0]
        if ident is None:
            return None
        sourceClass = self.select.sourceClass
        if not self.select.ops.get('lazyColumns', 0):
            # Pass the remaining columns along so the instance is filled
            # without another query.
            return sourceClass.get(ident, selectResults=row[1:], connection=self.dbconn)
        return sourceClass.get(ident, connection=self.dbconn)

    def _cleanup(self):
        if getattr(self, 'query', None) is None:
            # Already cleaned up, or __init__ never got as far as the query.
            return
        self.query = None
        if not self.keepConnection:
            self.dbconn.releaseConnection(self.rawconn)
        for attrName in ('dbconn', 'rawconn', 'select', 'cursor'):
            setattr(self, attrName, None)

    def __del__(self):
        self._cleanup()
0679
class Transaction(object):

    """
    A unit of work running on one raw connection taken from a DBAPI pool.

    The raw connection is switched out of auto-commit on creation.
    Queries run on it until :meth:`commit` or :meth:`rollback`, after which
    cached objects touched by the transaction are expired.  Attributes not
    defined here are looked up on the parent connection.  Methods found
    there are rebound so they run with this transaction as ``self``.
    """

    def __init__(self, dbConnection):
        # Start out obsolete so __del__ stays a no-op if the rest of the
        # constructor fails; flipped to active on the last line.
        self._obsolete = True
        self._dbConnection = dbConnection
        self._connection = dbConnection.getConnection()
        dbConnection._setAutoCommit(self._connection, 0)
        self.cache = CacheSet(cache=dbConnection.doCache)
        self._deletedCache = {}
        self._obsolete = False

    def assertActive(self):
        assert not self._obsolete, "This transaction has already gone through ROLLBACK; begin another transaction"

    def query(self, s):
        self.assertActive()
        db = self._dbConnection
        return db._query(self._connection, s)

    def queryAll(self, s):
        self.assertActive()
        db = self._dbConnection
        return db._queryAll(self._connection, s)

    def queryOne(self, s):
        self.assertActive()
        db = self._dbConnection
        return db._queryOne(self._connection, s)

    def queryInsertID(self, soInstance, id, names, values):
        self.assertActive()
        db = self._dbConnection
        return db._queryInsertID(self._connection, soInstance, id, names, values)

    def iterSelect(self, select):
        self.assertActive()
        # All rows are read up front.  Presumably this keeps the shared raw
        # connection free for other queries while the caller iterates.
        iteration = select.IterationClass(self, self._connection,
                                          select, keepConnection=True)
        rows = list(iteration)
        return iter(rows)

    def _SO_delete(self, inst):
        # Remember deleted ids per class name so commit() can expire them
        # in the parent connection's cache.
        self._deletedCache.setdefault(inst.__class__.__name__, []).append(inst.id)
        # Run the parent's implementation with this transaction as self.
        boundDelete = new.instancemethod(self._dbConnection._SO_delete.im_func,
                                         self, self.__class__)
        return boundDelete(inst)

    def commit(self, close=False):
        if self._obsolete:
            # Already finished: nothing to commit.
            return
        db = self._dbConnection
        if db.debug:
            db.printDebug(self._connection, '', 'COMMIT')
        self._connection.commit()
        # (class name, ids) pairs touched by this transaction: objects in
        # its own cache plus objects it deleted.
        touched = []
        for className, subCache in self.cache.allSubCachesByClassNames().items():
            touched.append((className, subCache.allIDs()))
        touched.extend(self._deletedCache.items())
        for className, ids in touched:
            for ident in ids:
                inst = db.cache.tryGetByName(ident, className)
                if inst is not None:
                    inst.expire()
        if close:
            self._makeObsolete()

    def rollback(self):
        if self._obsolete:
            # Already finished: nothing to roll back.
            return
        if self._dbConnection.debug:
            self._dbConnection.printDebug(self._connection, '', 'ROLLBACK')
        # Snapshot the cached ids before rolling back.
        pending = [(subCache, subCache.allIDs()) for subCache in self.cache.allSubCaches()]
        self._connection.rollback()

        for subCache, ids in pending:
            for ident in ids:
                inst = subCache.tryGet(ident)
                if inst is not None:
                    inst.expire()
        self._makeObsolete()

    def __getattr__(self, attr):
        """
        If nothing else works, let the parent connection handle it.
        Except with this transaction as 'self'. Poor man's
        acquisition? Bad programming? Okay, maybe.
        """
        self.assertActive()
        value = getattr(self._dbConnection, attr)
        try:
            func = value.im_func
        except AttributeError:
            if isinstance(value, ConnWrapper):
                # Re-bind class wrappers to this transaction.
                return ConnWrapper(value._soClass, self)
            return value
        return new.instancemethod(func, self, self.__class__)

    def _makeObsolete(self):
        self._obsolete = True
        db = self._dbConnection
        if db.autoCommit:
            db._setAutoCommit(self._connection, 1)
        db.releaseConnection(self._connection, explicit=True)
        self._connection = None
        self._deletedCache = {}

    def begin(self):
        # Only valid once commit(close=True) or rollback() has released
        # the previous raw connection.
        assert self._obsolete, "You cannot begin a new transaction session without rolling back this one"
        self._obsolete = False
        self._connection = self._dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0)

    def __del__(self):
        if not self._obsolete:
            self.rollback()

    def close(self):
        raise TypeError('You cannot just close transaction - you should either call rollback(), commit() or commit(close=True) to close the underlying connection.')
0806
class ConnectionHub(object):

    """
    A stand-in connection that resolves to a real one at access time.

    Assign a hub to a SQLObject subclass as though it were a connection,
    and bind the actual database connection later.  The binding can be
    per-thread (``threadConnection``) or process-wide
    (``processConnection``); the thread binding wins when both are set.

    Keep a reference to the hub itself: reading ``_connection`` from the
    class or an instance returns the resolved connection, not the hub.

    To use the hub, do something like::

        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub

        hub.threadConnection = connectionFromURI('...')

    """

    def __init__(self):
        self.threadingLocal = threading_local()

    def __get__(self, obj, type=None):
        # An instance may carry its own connection (stored by __set__);
        # that overrides whatever the hub is bound to.
        if obj is not None and '_connection' in obj.__dict__:
            return obj.__dict__['_connection']
        return self.getConnection()

    def __set__(self, obj, value):
        obj.__dict__['_connection'] = value

    def getConnection(self):
        try:
            return self.threadingLocal.connection
        except AttributeError:
            pass
        try:
            return self.processConnection
        except AttributeError:
            raise AttributeError(
                "No connection has been defined for this thread "
                "or process")

    def doInTransaction(self, func, *args, **kw):
        """
        Run ``func(*args, **kw)`` inside a fresh transaction.

        The transaction temporarily replaces whichever binding is active
        (thread-local if present, else process-wide).  It is rolled back
        if ``func`` raises and committed (and closed) otherwise.  The
        previous binding is restored in every case, and the function's
        return value is passed through.

        Use like::

            sqlhub.doInTransaction(process_request, os.environ)
        """
        try:
            previous = self.threadingLocal.connection
            slot = 'threadConnection'
        except AttributeError:
            previous = self.processConnection
            slot = 'processConnection'
        txn = previous.transaction()
        setattr(self, slot, txn)
        try:
            try:
                result = func(*args, **kw)
            except:
                txn.rollback()
                raise
            txn.commit(close=True)
            return result
        finally:
            setattr(self, slot, previous)

    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value

    def _get_threadConnection(self):
        return self.threadingLocal.connection

    def _del_threadConnection(self):
        del self.threadingLocal.connection

    threadConnection = property(fget=_get_threadConnection,
                                fset=_set_threadConnection,
                                fdel=_del_threadConnection)
0906
class ConnectionURIOpener(object):

    """
    Registry that opens connections by URI.

    It maps URI schemes to driver builders and names to connection
    instances.  Connections opened through it are cached per URI.
    """

    def __init__(self):
        # scheme -> zero-argument callable returning a connection class
        self.schemeBuilders = {}
        # connection name -> connection instance
        self.instanceNames = {}
        # full URI (including any encoded extra args) -> connection
        self.cachedURIs = {}

    def registerConnection(self, schemes, builder):
        """Register ``builder`` for every URI scheme in ``schemes``."""
        for uriScheme in schemes:
            assert (uriScheme not in self.schemeBuilders
                    or self.schemeBuilders[uriScheme] is builder), (
                "A driver has already been registered for the URI scheme %s"
                % uriScheme)
            self.schemeBuilders[uriScheme] = builder

    def registerConnectionInstance(self, inst):
        """
        Make a named connection reachable via ``connectionForURI(name)``.

        Unnamed connections are ignored.  Re-registering the same instance
        is allowed.  Registering a different instance under an existing
        name, or a name containing ':' (which would be parsed as a URI
        scheme), fails an assertion.
        """
        if inst.name:
            # Compare with ``inst`` itself; the previous code referenced an
            # undefined ``cls`` and raised NameError on re-registration.
            assert (inst.name not in self.instanceNames
                    or self.instanceNames[inst.name] is inst), (
                "An instance has already been registered with the name %s"
                % inst.name)
            assert inst.name.find(':') == -1, (
                "You cannot include ':' in your connection names (%r)"
                % inst.name)
            self.instanceNames[inst.name] = inst

    def connectionForURI(self, uri, **args):
        """
        Return the connection for ``uri``, creating and caching it on first use.

        Extra keyword ``args`` are URL-encoded onto the URI's query string,
        so they become part of the cache key too.  A URI without ':' is
        looked up among the registered connection names.
        """
        if args:
            if '?' not in uri:
                uri += '?' + urllib.urlencode(args)
            else:
                uri += '&' + urllib.urlencode(args)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if uri.find(':') != -1:
            scheme = uri.split(':', 1)[0]
            connCls = self.dbConnectionForScheme(scheme)
            conn = connCls.connectionFromURI(uri)
        else:
            # No scheme: the "URI" is a registered connection name.
            assert uri in self.instanceNames, "No SQLObject driver exists under the name %s" % uri
            conn = self.instanceNames[uri]
        self.cachedURIs[uri] = conn
        return conn

    def dbConnectionForScheme(self, scheme):
        """Return the connection class produced by the builder registered for ``scheme``."""
        assert scheme in self.schemeBuilders, (
            "No SQLObject driver exists for %s (only %s)"
            % (scheme, ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()
0955
# Process-wide opener.  The module-level helpers below are simply bound
# methods of this single instance.
TheURIOpener = ConnectionURIOpener()

(registerConnection,
 registerConnectionInstance,
 connectionForURI,
 dbConnectionForScheme) = (TheURIOpener.registerConnection,
                           TheURIOpener.registerConnectionInstance,
                           TheURIOpener.connectionForURI,
                           TheURIOpener.dbConnectionForScheme)
0962
0963
0964import firebird
0965import maxdb
0966import mssql
0967import mysql
0968import postgres
0969import rdbhost
0970import sqlite
0971import sybase