0001import threading
0002from util.threadinglocal import local as threading_local
0003import sys
0004import re
0005import warnings
0006import atexit
0007import os
0008import new
0009import types
0010import urllib
0011import weakref
0012import inspect
0013import sqlbuilder
0014from cache import CacheSet
0015import col
0016import main
0017from joins import sorter
0018from converters import sqlrepr
0019import classregistry
0020
0021warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")
0022
0023_connections = {}
0024
def _closeConnection(ref):
    """
    atexit hook: close the connection behind the weak reference *ref*.

    Does nothing if the connection has already been garbage collected.
    """
    target = ref()
    if target is None:
        return
    target.close()
0029
class ConsoleWriter:
    """Debug writer that prints each message on ``sys.stdout`` or another
    ``sys`` stream."""

    def __init__(self, connection, loglevel):
        # loglevel: None or empty string for stdout; or 'stderr'
        # (it names the attribute of ``sys`` to write to).
        if not loglevel:
            loglevel = "stdout"
        self.loglevel = loglevel
        # Unicode messages are encoded with the connection's dbEncoding,
        # falling back to ASCII when it has none.
        encoding = getattr(connection, "dbEncoding", None)
        self.dbEncoding = encoding or "ascii"

    def write(self, text):
        """Write *text* followed by a newline to the selected stream."""
        stream = getattr(sys, self.loglevel)
        if isinstance(text, unicode):
            try:
                text = text.encode(self.dbEncoding)
            except UnicodeEncodeError:
                # Use the escaped repr instead, slicing off the leading u'
                # and the trailing quote.
                text = repr(text)[2:-1]
        stream.write(text + '\n')
0043
class LogWriter:
    """Debug writer that forwards each message to a logger method."""

    def __init__(self, connection, logger, loglevel):
        # ``connection`` is accepted for signature parity with
        # ConsoleWriter; it is not used here.
        self.logger = logger
        self.loglevel = loglevel
        # Resolve e.g. ``logger.debug`` once so write() is a single call.
        self.logmethod = getattr(self.logger, self.loglevel)

    def write(self, text):
        """Emit *text* through the configured logger method."""
        emit = self.logmethod
        emit(text)
0051
def makeDebugWriter(connection, loggerName, loglevel):
    """
    Build the object that receives debug output for *connection*.

    With a *loggerName*, messages go to ``logging.getLogger(loggerName)``
    through its *loglevel* method (a LogWriter).  Otherwise a ConsoleWriter
    is returned, and *loglevel* then selects the ``sys`` stream.
    """
    if loggerName:
        import logging
        return LogWriter(connection, logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(connection, loglevel)
0058
class Boolean(object):
    """
    Callable that converts a value to a real ``bool``.

    Besides ordinary truthiness it understands the case-insensitive string
    keywords yes/no, true/false, on/off and 1/0.  Calling ``Boolean(x)``
    returns a ``bool``, never a Boolean instance.
    """
    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
                 '0': False, 'no': False, 'false': False, 'off': False}

    def __new__(cls, value):
        try:
            lowered = value.lower()
        except AttributeError:
            # Not string-like: plain truthiness.
            return bool(value)
        if lowered in Boolean._keywords:
            return Boolean._keywords[lowered]
        # Unknown word: any non-empty string is true.
        return bool(value)
0068
class DBConnection:

    """
    Base class for SQLObject database connections.

    Holds the settings shared by every backend: debug flags, the object
    cache, style, auto-commit mode and class registry.  It also provides URI
    formatting (`uri`) and parsing (`_parseURI`) helpers.  Subclasses must
    supply `connectionFromURI` and the backend-specific behaviour.
    """

    def __init__(self, name=None, debug=False, debugOutput=False,
                 cache=True, style=None, autoCommit=True,
                 debugThreading=False, registry=None,
                 logger=None, loglevel=None):
        # Flags pass through Boolean so that string keywords such as
        # 'yes' or 'off' are understood as well as real booleans.
        self.name = name
        self.debug = Boolean(debug)
        self.debugOutput = Boolean(debugOutput)
        self.debugThreading = Boolean(debugThreading)
        # Console writer unless a logger name is given.
        self.debugWriter = makeDebugWriter(self, logger, loglevel)
        self.doCache = Boolean(cache)
        self.cache = CacheSet(cache=self.doCache)
        self.style = style
        # id(raw connection) -> sequence number used to label debug output.
        self._connectionNumbers = {}
        self._connectionCount = 1
        self.autoCommit = Boolean(autoCommit)
        self.registry = registry or None
        # Get notified of SQLObject classes in our registry so that a
        # connection-bound wrapper can be attached (see soClassAdded).
        classregistry.registry(self.registry).addCallback(
            self.soClassAdded)
        registerConnectionInstance(self)
        # Close at interpreter exit.  A weak reference is used so the hook
        # does not keep this connection alive.
        atexit.register(_closeConnection, weakref.ref(self))

    def uri(self):
        """
        Return a URI string describing this connection.

        The scheme comes from ``dbName``.  The optional ``user`` and
        ``password`` attributes form the authority, followed by ``host``
        and ``db``.  A leading '/' on ``db`` is dropped because the URI
        already contains the separator.

        Raises AssertionError if a password is set without a user name.
        """
        # A missing *or* None user both mean "no credentials"; previously a
        # None user was rendered literally as "None".
        auth = getattr(self, 'user', None) or ''
        password = getattr(self, 'password', None)
        if auth:
            if password:
                auth = auth + ':' + password
            auth = auth + '@'
        else:
            assert not password, (
                'URIs cannot express passwords without usernames')
        uri = '%s://%s' % (self.dbName, auth)
        if self.host:
            uri += self.host
        uri += '/'
        db = self.db
        if db.startswith('/'):
            # This used to slice an undefined name ``path`` (NameError).
            db = db[1:]
        return uri + db

    def connectionFromURI(cls, uri):
        """
        Create a connection from *uri*.

        Subclasses must override this.  The base version raises
        NotImplementedError; it previously raised the ``NotImplemented``
        constant, which is not an exception class.
        """
        raise NotImplementedError
    connectionFromURI = classmethod(connectionFromURI)

    def _parseURI(uri):
        """
        Split a connection URI into its parts.

        Returns ``(user, password, host, port, path, args)``:

        - ``user``, ``password``, ``host`` and ``port`` may each be None.
        - ``port`` is an int in 1..65535 when present.
        - ``path`` keeps its leading '/'.  On Windows, 'C|/x' becomes 'C:/x'.
        - ``args`` maps query-string names to url-unquoted values.

        Raises ValueError for a non-integer or out-of-range port.  Raises
        AssertionError if the part after the scheme does not start with '/'.
        """
        schema, rest = uri.split(':', 1)
        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
        if rest.startswith('/') and not rest.startswith('//'):
            # scheme:/path -- no host part.
            host = None
            rest = rest[1:]
        elif rest.startswith('///'):
            # scheme:///path -- empty host part.
            host = None
            rest = rest[3:]
        else:
            # scheme://[user[:password]@]host[:port][/path]
            rest = rest[2:]
            if rest.find('/') == -1:
                host = rest
                rest = ''
            else:
                host, rest = rest.split('/', 1)
        if host and host.find('@') != -1:
            # rsplit so that an '@' inside the password is preserved.
            user, host = host.rsplit('@', 1)
            if user.find(':') != -1:
                user, password = user.split(':', 1)
            else:
                password = None
        else:
            user = password = None
        if host and host.find(':') != -1:
            _host, port = host.split(':')
            try:
                port = int(port)
            except ValueError:
                raise ValueError("port must be integer, got '%s' instead" % port)
            if not (1 <= port <= 65535):
                raise ValueError("port must be integer in the range 1-65535, got '%d' instead" % port)
            host = _host
        else:
            port = None
        path = '/' + rest
        if os.name == 'nt':
            # Windows drive letters are written as 'C|' inside URIs.
            if (len(rest) > 1) and (rest[1] == '|'):
                path = "%s:%s" % (rest[0], rest[2:])
        args = {}
        if path.find('?') != -1:
            path, arglist = path.split('?', 1)
            for single in arglist.split('&'):
                if not single:
                    # Tolerate empty segments such as a trailing '&'.
                    continue
                argname, argvalue = single.split('=', 1)
                args[argname] = urllib.unquote(argvalue)
        return user, password, host, port, path, args
    _parseURI = staticmethod(_parseURI)

    def soClassAdded(self, soClass):
        """
        This is called for each new class; we use this opportunity
        to create an instance method that is bound to the class
        and this connection.
        """
        name = soClass.__name__
        assert not hasattr(self, name), (
            "Connection %r already has an attribute with the name "
            "%r (and you just created the conflicting class %r)"
            % (self, name, soClass))
        setattr(self, name, ConnWrapper(soClass, self))

    def expireAll(self):
        """
        Expire all instances of objects for this connection.

        Calls ``weakrefAll()`` on the cache set before iterating its
        contents.
        """
        cache_set = self.cache
        cache_set.weakrefAll()
        for item in cache_set.getAll():
            item.expire()
0185
class ConnWrapper(object):

    """
    This represents a SQLObject class that is bound to a specific
    connection.  Instances have a connection instance variable, but
    classes are global, so this binds the connection variable lazily
    when a class method is accessed.
    """
    # @@: methods that take connection arguments should be explicitly
    # marked up instead of the implicit use of a connection argument
    # and inspect.getargspec()

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        """Instantiate the wrapped class; ``connection`` is always forced
        to this wrapper's connection."""
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        """
        Look *attr* up on the wrapped class.

        Non-methods are returned unchanged.  Methods that accept a
        ``connection`` argument come back wrapped in a ConnMethodWrapper
        bound to this connection.  Whether a method takes ``connection``
        is cached on its function as ``takes_connection``.
        """
        if attr in ('_soClass', '_connection'):
            # Reached only when __init__ did not run (e.g. an instance built
            # by copy/pickle).  Without this guard the lookup below would
            # recurse forever.
            raise AttributeError(attr)
        meth = getattr(self._soClass, attr)
        if not isinstance(meth, types.MethodType):
            # We don't need to wrap non-methods
            return meth
        try:
            takes_conn = meth.takes_connection
        except AttributeError:
            args, varargs, varkw, defaults = inspect.getargspec(meth)
            # With *args/**kw we cannot tell whether 'connection' would be
            # accepted, so refuse to guess.
            assert not varkw and not varargs, (
                "I cannot tell whether I must wrap this method, "
                "because it takes *args or **kw: %r"
                % meth)
            takes_conn = 'connection' in args
            meth.im_func.takes_connection = takes_conn
        if not takes_conn:
            return meth
        return ConnMethodWrapper(meth, self._connection)
0224
class ConnMethodWrapper(object):

    """
    A class method bound to a particular connection.

    Calling the wrapper calls the method with ``connection`` set to that
    connection, overriding any value the caller passed.  Other attribute
    lookups are delegated to the wrapped method.
    """

    def __init__(self, method, connection):
        self._method = method
        self._connection = connection

    def __getattr__(self, attr):
        if attr in ('_method', '_connection'):
            # Reached only when __init__ did not run (e.g. copy/pickle).
            # Without this guard the lookup below would recurse forever.
            raise AttributeError(attr)
        return getattr(self._method, attr)

    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        return self._method(*args, **kw)

    def __repr__(self):
        return '<Wrapped %r with connection %r>' % (
            self._method, self._connection)
0241
class DBAPI(DBConnection):

    """
    Connection base class for backends built on a DB-API driver.

    Subclass must define a `makeConnection()` method, which
    returns a newly-created connection object.

    ``_queryInsertID`` must also be defined; the public ``queryInsertID``
    below only runs it with a pooled connection.  The methods here also
    rely on:

    - ``self.module``, the driver module (its ``Binary`` and ``Error`` are
      used);
    - ``self.supportTransactions``;
    - ``createIDColumn`` and ``joinSQLType``.

    `Transaction` additionally calls ``_setAutoCommit``.
    """

    dbName = None

    def __init__(self, **kw):
        # Create the pool first: close() treats a missing _pool as a
        # partially constructed instance.
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        # Remember the type of the driver's binary wrapper objects.
        self._binaryType = type(self.module.Binary(''))

    def _runWithConnection(self, meth, *args):
        """Call ``meth(conn, *args)`` with a pooled connection and return
        its result.  The connection is released even if *meth* raises."""
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val

    def getConnection(self):
        """
        Take a raw connection from the pool.

        When the pool is empty a new one is created with
        ``makeConnection()`` and given a sequence number for debug output.
        Runs under ``_poolLock``.
        """
        self._poolLock.acquire()
        try:
            if not self._pool:
                conn = self.makeConnection()
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            else:
                conn = self._pool.pop()
            if self.debug:
                s = 'ACQUIRE'
                if self._pool is not None:
                    s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
                self.printDebug(conn, s, 'Pool')
            return conn
        finally:
            self._poolLock.release()

    def releaseConnection(self, conn, explicit=False):
        """
        Give back a raw connection obtained from `getConnection`.

        On an implicit release (``explicit`` false) with
        ``supportTransactions`` set, pending work is first resolved
        according to ``autoCommit``:

        - true: commit, unless the raw connection reports ``autocommit``;
        - false: roll back;
        - ``'exception'``: roll back, then raise.

        The connection is then returned to the pool, or closed when pooling
        is disabled (``_pool`` is None).
        """
        if self.debug:
            if explicit:
                s = 'RELEASE (explicit)'
            else:
                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                s += ' no pooling'
            else:
                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        if self.supportTransactions and not explicit:
            if self.autoCommit == 'exception':
                # NOTE(review): DBConnection.__init__ stores
                # Boolean(autoCommit), which turns 'exception' into True,
                # so this branch looks unreachable unless autoCommit is
                # reassigned after construction.
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                # Hand the connection back before raising.  It used to be
                # neither pooled nor closed on this path.
                self._returnConnection(conn)
                raise Exception('Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed')
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                # The driver connection is already in autocommit mode.
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        self._returnConnection(conn)

    def _returnConnection(self, conn):
        """Put *conn* back into the pool, or close it when pooling is off."""
        if self._pool is not None:
            if conn not in self._pool:
                # @@: We can get duplicate releasing of connections with
                # the __del__ in Iteration (unfortunately, not sure why
                # it happens)
                self._pool.insert(0, conn)
        else:
            conn.close()

    def printDebug(self, conn, s, name, type='query'):
        """
        Write one debug line through ``debugWriter``.

        The line is labelled with the connection's sequence number, the
        thread name (when ``debugThreading`` is on) and *name*.  With
        ``type`` 'query', *s* is printed as is; otherwise its repr is
        printed.  'Pool' messages are printed only when ``self.debug ==
        'Pool'``.
        """
        # NOTE(review): DBConnection.__init__ stores Boolean(debug), which
        # turns 'Pool' into True, so pool messages appear to be suppressed
        # unless debug is reassigned after construction.
        if name == 'Pool' and self.debug != 'Pool':
            return
        if type == 'query':
            sep = ': '
        else:
            sep = '->'
            s = repr(s)
        n = self._connectionNumbers[id(conn)]
        spaces = ' '*(8-len(name))
        if self.debugThreading:
            threadName = threading.currentThread().getName()
            threadName = (':' + threadName + ' '*(8-len(threadName)))
        else:
            threadName = ''
        # The format string pulls n, threadName, name, spaces, sep and s
        # from the local namespace.
        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
        self.debugWriter.write(msg)

    def _executeRetry(self, conn, cursor, query):
        """Execute *query* on *cursor*.  This base version runs it once; the
        name suggests subclasses may add retry logic."""
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)

    def _query(self, conn, s):
        """Execute *s* on a fresh cursor of *conn*, discarding any result."""
        if self.debug:
            self.printDebug(conn, s, 'Query')
        self._executeRetry(conn, conn.cursor(), s)

    def query(self, s):
        return self._runWithConnection(self._query, s)

    def _queryAll(self, conn, s):
        """Execute *s* on *conn* and return ``cursor.fetchall()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return value

    def queryAll(self, s):
        return self._runWithConnection(self._queryAll, s)

    def _queryAllDescription(self, conn, s):
        """
        Like queryAll, but returns (description, rows), where the
        description is cursor.description (which gives row types)
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return c.description, value

    def queryAllDescription(self, s):
        return self._runWithConnection(self._queryAllDescription, s)

    def _queryOne(self, conn, s):
        """Execute *s* on *conn* and return ``cursor.fetchone()``."""
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchone()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryOne', 'result')
        return value

    def queryOne(self, s):
        return self._runWithConnection(self._queryOne, s)

    def _insertSQL(self, table, names, values):
        """Build an INSERT statement; values are rendered with sqlrepr."""
        return ("INSERT INTO %s (%s) VALUES (%s)" %
                (table, ', '.join(names),
                 ', '.join([self.sqlrepr(v) for v in values])))

    def transaction(self):
        return Transaction(self)

    def queryInsertID(self, soInstance, id, names, values):
        return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)

    def iterSelect(self, select):
        """Return the select's IterationClass over a pooled connection.  The
        iteration releases the connection itself (keepConnection=False)."""
        return select.IterationClass(self, self.getConnection(),
                         select, keepConnection=False)

    def accumulateSelect(self, select, *expressions):
        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
            to the select object.

            Returns the single value for one expression, otherwise the row.
        """
        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
        q = self.sqlrepr(q)
        val = self.queryOne(q)
        if len(expressions) == 1:
            val = val[0]
        return val

    def queryForSelect(self, select):
        return self.sqlrepr(select.queryForSelect())

    def _SO_createJoinTable(self, join):
        self.query(self._SO_createJoinTableSQL(join))

    def _SO_createJoinTableSQL(self, join):
        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
                (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join)))

    def _SO_dropJoinTable(self, join):
        self.query("DROP TABLE %s" % join.intermediateTable)

    def _SO_createIndex(self, soClass, index):
        self.query(self.createIndexSQL(soClass, index))

    def createIndexSQL(self, soClass, index):
        assert 0, 'Implement in subclasses'

    def createTable(self, soClass):
        """Run CREATE TABLE for *soClass*.  Returns the constraint and extra
        SQL statements that the caller still has to execute."""
        createSql, constraints = self.createTableSQL(soClass)
        self.query(createSql)

        return constraints

    def createReferenceConstraints(self, soClass):
        """Return the non-empty reference-constraint SQL for every foreign
        key column of *soClass*."""
        refConstraints = [self.createReferenceConstraint(soClass, column)
                          for column in soClass.sqlmeta.columnList
                          if isinstance(column, col.SOForeignKey)]
        refConstraintDefs = [constraint
                             for constraint in refConstraints
                             if constraint]
        return refConstraintDefs

    def createSQL(self, soClass):
        """
        Return ``soClass.sqlmeta.createSQL`` normalised to a list.

        The value may be a string, a list, a tuple, or a dict keyed by the
        connection's dbName.  Returns [] when the value is unset or empty.
        """
        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
        if tableCreateSQLs:
            # basestring so unicode statements are accepted, not only str.
            assert isinstance(tableCreateSQLs,(basestring,list,dict,tuple)), (
                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
                (soClass.__name__))
            if isinstance(tableCreateSQLs, dict):
                tableCreateSQLs = tableCreateSQLs.get(soClass._connection.dbName, [])
            if isinstance(tableCreateSQLs, basestring):
                tableCreateSQLs = [tableCreateSQLs]
            if isinstance(tableCreateSQLs, tuple):
                tableCreateSQLs = list(tableCreateSQLs)
            assert isinstance(tableCreateSQLs,list), (
                'Unable to create a list from %s.sqlmeta.createSQL' %
                (soClass.__name__))
        return tableCreateSQLs or []

    def createTableSQL(self, soClass):
        """Return ``(create_table_sql, constraints_plus_extra_sql)``."""
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        createSql = ('CREATE TABLE %s (\n%s\n)' %
                (soClass.sqlmeta.table, self.createColumns(soClass)))
        return createSql, constraints + extraSQL

    def createColumns(self, soClass):
        """Return the column definitions (id column first), one per line."""
        # The loop variable is not named ``col`` so it does not shadow the
        # ``col`` module.
        columnDefs = [self.createIDColumn(soClass)] \
                     + [self.createColumn(soClass, column)
                        for column in soClass.sqlmeta.columnList]
        return ",\n".join(["    %s" % c for c in columnDefs])

    def createReferenceConstraint(self, soClass, col):
        assert 0, "Implement in subclasses"

    def createColumn(self, soClass, col):
        assert 0, "Implement in subclasses"

    def dropTable(self, tableName, cascade=False):
        self.query("DROP TABLE %s" % tableName)

    def clearTable(self, tableName):
        # 3-03 @@: Should this have a WHERE 1 = 1 or similar
        # clause?  In some configurations without the WHERE clause
        # the query won't go through, but maybe we shouldn't override
        # that.
        self.query("DELETE FROM %s" % tableName)

    def createBinary(self, value):
        """
        Create a binary object wrapper for the given database.
        """
        # Default is Binary() function from the connection driver.
        return self.module.Binary(value)

    # The _SO_* series of methods are sorts of "friend" methods
    # with SQLObject.  They grab values from the SQLObject instances
    # or classes freely, but keep the SQLObject class from accessing
    # the database directly.  This way no SQL is actually created
    # in the SQLObject class.

    def _SO_update(self, so, values):
        """UPDATE the row of *so* with ``(dbName, value)`` pairs."""
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                               for dbName, value in values]),
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectOne(self, so, columnNames):
        return self._SO_selectOneAlt(so, columnNames, so.q.id==so.id)


    def _SO_selectOneAlt(self, so, columnNames, condition):
        """SELECT *columnNames* (plain strings become SQL constants) from
        the table of *so* where *condition* holds; return one row."""
        if columnNames:
            columns = [isinstance(x, basestring) and sqlbuilder.SQLConstant(x) or x for x in columnNames]
        else:
            columns = None
        return self.queryOne(self.sqlrepr(sqlbuilder.Select(columns,
                                                            staticTables=[so.sqlmeta.table],
                                                            clause=condition)))

    def _SO_delete(self, so):
        self.query("DELETE FROM %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectJoin(self, soClass, column, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (soClass.sqlmeta.idName,
                              soClass.sqlmeta.table,
                              column,
                              self.sqlrepr(value)))

    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (getColumn,
                              table,
                              joinColumn,
                              self.sqlrepr(value)))

    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
                   (table,
                    firstColumn,
                    self.sqlrepr(firstValue),
                    secondColumn,
                    self.sqlrepr(secondValue)))

    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
                   (table,
                    firstColumn,
                    secondColumn,
                    self.sqlrepr(firstValue),
                    self.sqlrepr(secondValue)))

    def _SO_columnClause(self, soClass, kw):
        """
        Build a WHERE clause from keyword arguments.

        Each keyword names ``id``, a column, or a column's foreign name.
        SQLObject values given for a foreign name are replaced by their id.
        Terms are joined with AND, and None values compare with IS.

        Consumes (pops) the recognised keys from *kw*.  Returns None when
        no term results.  Raises TypeError for unknown keywords.
        """
        ops = {None: "IS"}
        data = {}
        if 'id' in kw:
            data[soClass.sqlmeta.idName] = kw.pop('id')
        for key, column in soClass.sqlmeta.columns.items():
            if key in kw:
                value = kw.pop(key)
                if column.from_python:
                    value = column.from_python(value, sqlbuilder.SQLObjectState(soClass))
                data[column.dbName] = value
            elif column.foreignName in kw:
                obj = kw.pop(column.foreignName)
                if isinstance(obj, main.SQLObject):
                    data[column.dbName] = obj.id
                else:
                    data[column.dbName] = obj
        if kw:
            # Anything left over matched no column.
            raise TypeError("got an unexpected keyword argument(s): %r" % kw.keys())

        if not data:
            return None
        return ' AND '.join(
            ['%s %s %s' %
             (dbName, ops.get(value, "="), self.sqlrepr(value))
             for dbName, value
             in data.items()])

    def sqlrepr(self, v):
        return sqlrepr(v, self.dbName)

    def __del__(self):
        self.close()

    def close(self):
        """
        Close every pooled connection and empty the pool.

        Driver errors (``self.module.Error``) from individual closes are
        ignored.  Safe on a partially constructed instance and on an empty
        or already-drained pool.
        """
        if not hasattr(self, '_pool'):
            # Probably there was an exception while creating this
            # instance, so it is incomplete.
            return
        if not self._pool:
            return
        self._poolLock.acquire()
        try:
            # Snapshot under the lock.  Another thread may have drained the
            # pool after the unlocked check above; the old trailing
            # ``del conn`` then raised UnboundLocalError.
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    pass
        finally:
            self._poolLock.release()

    def createEmptyDatabase(self):
        """
        Create an empty database.
        """
        raise NotImplementedError
0637
class Iteration(object):

    """
    Iterator over the result rows of a select, yielding SQLObject instances.

    The query is executed on *rawconn* when the object is created.  Each
    step fetches one row.  When the rows run out, or the object is
    collected, the raw connection is released back to *dbconn* unless
    *keepConnection* is true.
    """

    def __init__(self, dbconn, rawconn, select, keepConnection=False):
        self.dbconn = dbconn
        self.rawconn = rawconn
        self.select = select
        self.keepConnection = keepConnection
        self.cursor = rawconn.cursor()
        self.query = self.dbconn.queryForSelect(select)
        if dbconn.debug:
            dbconn.printDebug(rawconn, self.query, 'Select')
        self.dbconn._executeRetry(self.rawconn, self.cursor, self.query)

    def __iter__(self):
        return self

    def next(self):
        """
        Return the object for the next row, or raise StopIteration.

        A row whose first column (the id) is NULL yields None.  With the
        select's ``lazyColumns`` op set, only the id is passed to
        ``sourceClass.get``; otherwise the remaining columns are passed
        as ``selectResults``.
        """
        if self.query is None:
            # Already exhausted and cleaned up: the cursor is gone, so keep
            # honouring the iterator protocol instead of failing on None.
            raise StopIteration
        result = self.cursor.fetchone()
        if result is None:
            self._cleanup()
            raise StopIteration
        if result[0] is None:
            return None
        if self.select.ops.get('lazyColumns', 0):
            obj = self.select.sourceClass.get(result[0], connection=self.dbconn)
            return obj
        else:
            obj = self.select.sourceClass.get(result[0], selectResults=result[1:], connection=self.dbconn)
            return obj

    def __next__(self):
        # Python 3 spelling of the iterator protocol.  Delegates so that
        # subclasses overriding next() are honoured.
        return self.next()

    def _cleanup(self):
        """Release the connection (unless kept) and drop all references.
        Idempotent: ``query`` is set to None to mark completion."""
        if getattr(self, 'query', None) is None:
            # already cleaned up
            return
        self.query = None
        if not self.keepConnection:
            self.dbconn.releaseConnection(self.rawconn)
        self.dbconn = self.rawconn = self.select = self.cursor = None

    def __del__(self):
        self._cleanup()
0679
class Transaction(object):

    """
    A unit of work on one dedicated raw connection of a DBAPI connection.

    Created by ``DBAPI.transaction()``.  Queries run on the held connection
    with auto-commit switched off.  `commit` and `rollback` end the unit of
    work and expire affected cached objects.  Attributes not defined here
    are fetched from the parent connection; methods are rebound so that
    they run with this transaction as ``self``.
    """

    def __init__(self, dbConnection):
        # Start out "obsolete" so that __del__ is a no-op if anything
        # below raises before construction completes.
        self._obsolete = True
        self._dbConnection = dbConnection
        self._connection = dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0)
        self.cache = CacheSet(cache=dbConnection.doCache)
        # class name -> ids deleted through this transaction.
        self._deletedCache = {}
        self._obsolete = False

    def assertActive(self):
        assert not self._obsolete, "This transaction has already gone through ROLLBACK; begin another transaction"

    def query(self, s):
        self.assertActive()
        parent, raw = self._dbConnection, self._connection
        return parent._query(raw, s)

    def queryAll(self, s):
        self.assertActive()
        parent, raw = self._dbConnection, self._connection
        return parent._queryAll(raw, s)

    def queryOne(self, s):
        self.assertActive()
        parent, raw = self._dbConnection, self._connection
        return parent._queryOne(raw, s)

    def queryInsertID(self, soInstance, id, names, values):
        self.assertActive()
        parent, raw = self._dbConnection, self._connection
        return parent._queryInsertID(raw, soInstance, id, names, values)

    def iterSelect(self, select):
        self.assertActive()
        # We can't keep the cursor open with results in a transaction,
        # because we might want to use the connection while we're
        # still iterating through the results.
        # @@: But would it be okay for psycopg, with threadsafety
        # level 2?
        rows = list(select.IterationClass(self, self._connection,
                                          select, keepConnection=True))
        return iter(rows)

    def _SO_delete(self, inst):
        # Record the id so commit() can expire the matching instance in
        # the parent connection's cache.
        self._deletedCache.setdefault(inst.__class__.__name__, []).append(inst.id)
        # Run the parent's _SO_delete with this transaction as ``self`` so
        # the DELETE goes through this transaction's query().
        parentFunc = self._dbConnection._SO_delete.im_func
        bound = new.instancemethod(parentFunc, self, self.__class__)
        return bound(inst)

    def commit(self, close=False):
        """
        Commit the held connection.

        Afterwards, every object this transaction cached or deleted is
        expired in the parent connection's cache.  With *close* true the
        transaction is also finished and its connection released.  Does
        nothing once the transaction is obsolete.
        """
        if self._obsolete:
            # @@: is it okay to get extraneous commits?
            return
        if self._dbConnection.debug:
            self._dbConnection.printDebug(self._connection, '', 'COMMIT')
        self._connection.commit()
        touched = [(clsName, subCache.allIDs())
                   for clsName, subCache
                   in self.cache.allSubCachesByClassNames().items()]
        touched.extend(self._deletedCache.items())
        for clsName, ids in touched:
            for id in ids:
                inst = self._dbConnection.cache.tryGetByName(id, clsName)
                if inst is not None:
                    inst.expire()
        if close:
            self._makeObsolete()

    def rollback(self):
        """
        Roll back the held connection.

        Objects cached by this transaction are expired, and the transaction
        is finished.  Does nothing once it is obsolete.
        """
        if self._obsolete:
            # @@: is it okay to get extraneous rollbacks?
            return
        if self._dbConnection.debug:
            self._dbConnection.printDebug(self._connection, '', 'ROLLBACK')
        # Capture the cached ids before rolling back.
        pending = [(subCache, subCache.allIDs())
                   for subCache in self.cache.allSubCaches()]
        self._connection.rollback()

        for subCache, ids in pending:
            for id in ids:
                inst = subCache.tryGet(id)
                if inst is not None:
                    inst.expire()
        self._makeObsolete()

    def __getattr__(self, attr):
        """
        If nothing else works, let the parent connection handle it.
        Except with this transaction as 'self'.  Poor man's
        acquisition?  Bad programming?  Okay, maybe.
        """
        self.assertActive()
        found = getattr(self._dbConnection, attr)
        try:
            func = found.im_func
        except AttributeError:
            # Class wrappers are re-bound to this transaction.
            if isinstance(found, ConnWrapper):
                return ConnWrapper(found._soClass, self)
            return found
        return new.instancemethod(func, self, self.__class__)

    def _makeObsolete(self):
        """Mark finished, restore auto-commit if the parent uses it, and
        explicitly release the held connection."""
        self._obsolete = True
        if self._dbConnection.autoCommit:
            self._dbConnection._setAutoCommit(self._connection, 1)
        self._dbConnection.releaseConnection(self._connection,
                                             explicit=True)
        self._connection = None
        self._deletedCache = {}

    def begin(self):
        """Start a new unit of work on a fresh connection after a previous
        commit(close=True) or rollback()."""
        # @@: Should we do this, or should begin() be a no-op when we're
        # not already obsolete?
        assert self._obsolete, "You cannot begin a new transaction session without rolling back this one"
        self._obsolete = False
        self._connection = self._dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0)

    def __del__(self):
        # An unfinished transaction is rolled back when collected.
        if not self._obsolete:
            self.rollback()

    def close(self):
        raise TypeError('You cannot just close transaction - you should either call rollback(), commit() or commit(close=True) to close the underlying connection.')
0806
class ConnectionHub(object):

    """
    This object serves as a hub for connections, so that you can pass
    in a ConnectionHub to a SQLObject subclass as though it was a
    connection, but actually bind a real database connection later.
    You can also bind connections on a per-thread basis.

    You must hang onto the original ConnectionHub instance, as you
    cannot retrieve it again from the class or instance.

    To use the hub, do something like::

        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub

        hub.threadConnection = connectionForURI('...')

    A connection bound with ``hub.threadConnection`` takes precedence, for
    the current thread, over one set as ``hub.processConnection``.
    """

    def __init__(self):
        # Per-thread storage; its ``connection`` attribute backs the
        # ``threadConnection`` property.
        self.threadingLocal = threading_local()

    def __get__(self, obj, type=None):
        """
        Descriptor read: return the connection to use for *obj*.

        A per-instance override stored by ``__set__`` wins.  Otherwise the
        hub's current thread/process connection is returned.
        """
        # Because this class defines __set__ it is a *data* descriptor, and
        # data descriptors take precedence over the instance __dict__.  The
        # per-instance override therefore has to be looked up by hand.
        if obj is not None and '_connection' in obj.__dict__:
            return obj.__dict__['_connection']
        return self.getConnection()

    def __set__(self, obj, value):
        # Store a per-instance override that __get__ will return.
        obj.__dict__['_connection'] = value

    def _currentConnection(self):
        """
        Return ``(connection, isThreadConnection)`` for the connection now
        in effect: the thread connection if one is bound, else the
        process connection.

        Raises AttributeError if neither has been set.
        """
        try:
            return self.threadingLocal.connection, True
        except AttributeError:
            pass
        try:
            return self.processConnection, False
        except AttributeError:
            raise AttributeError(
                "No connection has been defined for this thread "
                "or process")

    def getConnection(self):
        """
        Return the connection bound to the current thread, falling back
        to ``processConnection``.

        Raises AttributeError if neither has been set.
        """
        return self._currentConnection()[0]

    def doInTransaction(self, func, *args, **kw):
        """
        This routine can be used to run a function in a transaction,
        rolling the transaction back if any exception is raised from
        that function, and committing otherwise.

        Use like::

            sqlhub.doInTransaction(process_request, os.environ)

        This will run ``process_request(os.environ)``.  The return
        value will be preserved.

        While *func* runs, the transaction replaces whichever connection
        is in effect: the thread connection if one is bound, otherwise the
        process connection.  The original is restored afterwards.  On
        success the transaction is committed with ``close=True``.

        Raises AttributeError, with the same message as getConnection(), if
        no connection is bound for this thread or process.
        """
        # @@: In Python 2.5, something usable with with: should also
        # be added.
        old_conn, old_conn_is_threading = self._currentConnection()
        conn = old_conn.transaction()
        if old_conn_is_threading:
            self.threadConnection = conn
        else:
            self.processConnection = conn
        try:
            try:
                value = func(*args, **kw)
            except:
                # Bare except on purpose: roll back on *any* exception and
                # re-raise it unchanged.
                conn.rollback()
                raise
            else:
                conn.commit(close=True)
                return value
        finally:
            # Put back whichever connection was bound before.
            if old_conn_is_threading:
                self.threadConnection = old_conn
            else:
                self.processConnection = old_conn

    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value

    def _get_threadConnection(self):
        return self.threadingLocal.connection

    def _del_threadConnection(self):
        del self.threadingLocal.connection

    # Connection bound to the current thread only (stored in threadingLocal).
    threadConnection = property(_get_threadConnection,
                                _set_threadConnection,
                                _del_threadConnection)
0906
class ConnectionURIOpener(object):

    """
    Registry that turns URIs (or bare registered names) into connections.

    Drivers register builders per URI scheme.  Ready-made connection
    instances can be registered under a name.  Every connection handed out
    by connectionForURI() is cached under the exact URI string used to
    obtain it.
    """

    def __init__(self):
        # URI scheme -> zero-argument builder; calling it yields an object
        # with a connectionFromURI(uri) method (see connectionForURI).
        self.schemeBuilders = {}
        # registered name -> connection instance
        self.instanceNames = {}
        # final URI string -> connection previously returned for it
        self.cachedURIs = {}

    def registerConnection(self, schemes, builder):
        """
        Register *builder* for every URI scheme in *schemes*.

        Registering the same builder again for a scheme is allowed.  A
        different builder for an already-taken scheme fails an assertion.
        """
        for uriScheme in schemes:
            assert (uriScheme not in self.schemeBuilders
                    or self.schemeBuilders[uriScheme] is builder), \
                "A driver has already been registered for the URI scheme %s" % uriScheme
            self.schemeBuilders[uriScheme] = builder

    def registerConnectionInstance(self, inst):
        """
        Register connection *inst* under ``inst.name``, so that
        connectionForURI(inst.name) returns it.

        Instances with an empty/false name are ignored.  Registering the
        same instance twice is allowed.  A different instance under a taken
        name, or a name containing ':', fails an assertion.
        """
        if inst.name:
            # Compare against the instance itself (previously this
            # referenced an undefined name and raised NameError).
            assert (inst.name not in self.instanceNames
                    or self.instanceNames[inst.name] is inst), \
                "A instance has already been registered with the name %s" % inst.name
            # ':' would make the name indistinguishable from a scheme URI.
            assert ':' not in inst.name, \
                "You cannot include ':' in your class names (%r)" % inst.name
            self.instanceNames[inst.name] = inst

    def connectionForURI(self, uri, **args):
        """
        Return a connection for *uri*, creating and caching it on first use.

        Keyword *args* are URL-encoded and appended to the query string
        before the lookup.  A *uri* containing ':' is dispatched on its
        scheme.  Otherwise it is treated as a name registered with
        registerConnectionInstance().  Results are cached per final URI
        string.
        """
        if args:
            if '?' not in uri:
                uri += '?' + urllib.urlencode(args)
            else:
                uri += '&' + urllib.urlencode(args)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if ':' in uri:
            scheme = uri.split(':', 1)[0]
            connCls = self.dbConnectionForScheme(scheme)
            conn = connCls.connectionFromURI(uri)
        else:
            # We just have a name, not a URI
            assert uri in self.instanceNames, \
                "No SQLObject driver exists under the name %s" % uri
            conn = self.instanceNames[uri]
        # @@: Do we care if we clobber another connection?
        self.cachedURIs[uri] = conn
        return conn

    def dbConnectionForScheme(self, scheme):
        """
        Call the builder registered for *scheme* and return its result.

        Fails an assertion, listing the known schemes, if *scheme* has no
        registered builder.
        """
        assert scheme in self.schemeBuilders, (
               "No SQLObject driver exists for %s (only %s)"
               % (scheme, ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()
0955
# Module-wide default opener.  The functions below are its bound methods,
# so they all share one registry of schemes, named instances and cached URIs.
TheURIOpener = ConnectionURIOpener()

# Convenience aliases exposing TheURIOpener's API as module-level functions.
registerConnection = TheURIOpener.registerConnection
registerConnectionInstance = TheURIOpener.registerConnectionInstance
connectionForURI = TheURIOpener.connectionForURI
dbConnectionForScheme = TheURIOpener.dbConnectionForScheme
0962
# Register DB URI schemes: importing each driver module registers its scheme(s)
0964import firebird
0965import maxdb
0966import mssql
0967import mysql
0968import postgres
0969import rdbhost
0970import sqlite
0971import sybase