0001import atexit
0002from cgi import parse_qsl
0003import inspect
0004import new
0005import os
0006import sys
0007import threading
0008import types
0009import urllib
0010import warnings
0011import weakref
0012
0013from cache import CacheSet
0014import classregistry
0015import col
0016from converters import sqlrepr
0017import main
0018import sqlbuilder
0019from util.threadinglocal import local as threading_local
0020
0021warnings.filterwarnings("ignore", "DB-API extension cursor.lastrowid used")
0022
0023_connections = {}
0024
0025def _closeConnection(ref):
0026    conn = ref()
0027    if conn is not None:
0028        conn.close()
0029
class ConsoleWriter:
    """Debug writer that prints each message on sys.stdout or sys.stderr."""

    def __init__(self, connection, loglevel):
        # An empty/None loglevel means stdout; otherwise it names the sys
        # stream attribute to use, e.g. 'stderr'.
        if not loglevel:
            loglevel = "stdout"
        self.loglevel = loglevel
        encoding = getattr(connection, "dbEncoding", None)
        if not encoding:
            encoding = "ascii"
        self.dbEncoding = encoding

    def write(self, text):
        """Write *text* plus a newline, encoding unicode with dbEncoding."""
        stream = getattr(sys, self.loglevel)
        if isinstance(text, unicode):
            try:
                text = text.encode(self.dbEncoding)
            except UnicodeEncodeError:
                # Fall back to the escaped repr with the u'...' wrapper cut off.
                text = repr(text)[2:-1]
        stream.write(text + '\n')
0043
class LogWriter:
    """Debug writer that forwards each message to a method of a logger."""

    def __init__(self, connection, logger, loglevel):
        # connection is accepted for signature parity with ConsoleWriter;
        # it is not used here.
        self.logger = logger
        # Like ConsoleWriter's "stdout" fallback: an empty/None loglevel
        # (DBConnection's default) used to crash in getattr(logger, None).
        self.loglevel = loglevel or "debug"
        self.logmethod = getattr(logger, self.loglevel)

    def write(self, text):
        """Pass *text* to the selected logger method."""
        self.logmethod(text)
0051
def makeDebugWriter(connection, loggerName, loglevel):
    """Build the debug writer used by *connection*.

    With a logger name, messages go to ``logging.getLogger(loggerName)``
    through LogWriter; otherwise they are printed by ConsoleWriter.
    """
    if loggerName:
        import logging
        return LogWriter(connection, logging.getLogger(loggerName), loglevel)
    return ConsoleWriter(connection, loglevel)
0058
class Boolean(object):
    """
    Convert a value to a plain ``bool``, recognising the string keywords
    yes/no, true/false, on/off and 1/0 case-insensitively.

    Calling ``Boolean(x)`` returns a ``bool``, never a Boolean instance;
    values that are not one of the keywords fall back to ``bool(value)``.
    """
    _keywords = {'1': True, 'yes': True, 'true': True, 'on': True,
                 '0': False, 'no': False, 'false': False, 'off': False}

    def __new__(cls, value):
        try:
            key = value.lower()
        except AttributeError:
            # Not string-like: use ordinary truthiness.
            return bool(value)
        result = Boolean._keywords.get(key)
        if result is None:
            return bool(value)
        return result
0068
0069class DBConnection:
0070
0071    def __init__(self, name=None, debug=False, debugOutput=False,
0072                 cache=True, style=None, autoCommit=True,
0073                 debugThreading=False, registry=None,
0074                 logger=None, loglevel=None):
0075        self.name = name
0076        self.debug = Boolean(debug)
0077        self.debugOutput = Boolean(debugOutput)
0078        self.debugThreading = Boolean(debugThreading)
0079        self.debugWriter = makeDebugWriter(self, logger, loglevel)
0080        self.doCache = Boolean(cache)
0081        self.cache = CacheSet(cache=self.doCache)
0082        self.style = style
0083        self._connectionNumbers = {}
0084        self._connectionCount = 1
0085        self.autoCommit = Boolean(autoCommit)
0086        self.registry = registry or None
0087        classregistry.registry(self.registry).addCallback(self.soClassAdded)
0088        registerConnectionInstance(self)
0089        atexit.register(_closeConnection, weakref.ref(self))
0090
0091    def oldUri(self):
0092        auth = getattr(self, 'user', '') or ''
0093        if auth:
0094            if self.password:
0095                auth = auth + ':' + self.password
0096            auth = auth + '@'
0097        else:
0098            assert not getattr(self, 'password', None), (
0099                'URIs cannot express passwords without usernames')
0100        uri = '%s://%s' % (self.dbName, auth)
0101        if self.host:
0102            uri += self.host
0103            if self.port:
0104                uri += ':%d' % self.port
0105        uri += '/'
0106        db = self.db
0107        if db.startswith('/'):
0108            db = db[1:]
0109        return uri + db
0110
0111    def uri(self):
0112        auth = getattr(self, 'user', '') or ''
0113        if auth:
0114            auth = urllib.quote(auth)
0115            if self.password:
0116                auth = auth + ':' + urllib.quote(self.password)
0117            auth = auth + '@'
0118        else:
0119            assert not getattr(self, 'password', None), (
0120                'URIs cannot express passwords without usernames')
0121        uri = '%s://%s' % (self.dbName, auth)
0122        if self.host:
0123            uri += self.host
0124            if self.port:
0125                uri += ':%d' % self.port
0126        uri += '/'
0127        db = self.db
0128        if db.startswith('/'):
0129            db = db[1:]
0130        return uri + urllib.quote(db)
0131
0132    @classmethod
0133    def connectionFromOldURI(cls, uri):
0134        return cls._connectionFromParams(*cls._parseOldURI(uri))
0135
0136    @classmethod
0137    def connectionFromURI(cls, uri):
0138        return cls._connectionFromParams(*cls._parseURI(uri))
0139
0140    @staticmethod
0141    def _parseOldURI(uri):
0142        schema, rest = uri.split(':', 1)
0143        assert rest.startswith('/'), "URIs must start with scheme:/ -- you did not include a / (in %r)" % rest
0144        if rest.startswith('/') and not rest.startswith('//'):
0145            host = None
0146            rest = rest[1:]
0147        elif rest.startswith('///'):
0148            host = None
0149            rest = rest[3:]
0150        else:
0151            rest = rest[2:]
0152            if rest.find('/') == -1:
0153                host = rest
0154                rest = ''
0155            else:
0156                host, rest = rest.split('/', 1)
0157        if host and host.find('@') != -1:
0158            user, host = host.rsplit('@', 1)
0159            if user.find(':') != -1:
0160                user, password = user.split(':', 1)
0161            else:
0162                password = None
0163        else:
0164            user = password = None
0165        if host and host.find(':') != -1:
0166            _host, port = host.split(':')
0167            try:
0168                port = int(port)
0169            except ValueError:
0170                raise ValueError, "port must be integer, got '%s' instead" % port
0171            if not (1 <= port <= 65535):
0172                raise ValueError, "port must be integer in the range 1-65535, got '%d' instead" % port
0173            host = _host
0174        else:
0175            port = None
0176        path = '/' + rest
0177        if os.name == 'nt':
0178            if (len(rest) > 1) and (rest[1] == '|'):
0179                path = "%s:%s" % (rest[0], rest[2:])
0180        args = {}
0181        if path.find('?') != -1:
0182            path, arglist = path.split('?', 1)
0183            arglist = arglist.split('&')
0184            for single in arglist:
0185                argname, argvalue = single.split('=', 1)
0186                argvalue = urllib.unquote(argvalue)
0187                args[argname] = argvalue
0188        return user, password, host, port, path, args
0189
0190    @staticmethod
0191    def _parseURI(uri):
0192        protocol, request = urllib.splittype(uri)
0193        user, password, port = None, None, None
0194        host, path = urllib.splithost(request)
0195
0196        if host:
0197            # Python < 2.7 have a problem - splituser() calls unquote() too early
0198            #user, host = urllib.splituser(host)
0199            if '@' in host:
0200                user, host = host.split('@', 1)
0201            if user:
0202                user, password = [x and urllib.unquote(x) or None for x in urllib.splitpasswd(user)]
0203            host, port = urllib.splitport(host)
0204            if port: port = int(port)
0205        elif host == '':
0206            host = None
0207
0208        # hash-tag is splitted but ignored
0209        path, tag = urllib.splittag(path)
0210        path, query = urllib.splitquery(path)
0211
0212        path = urllib.unquote(path)
0213        if (os.name == 'nt') and (len(path) > 2):
0214            # Preserve backward compatibility with URIs like /C|/path;
0215            # replace '|' by ':'
0216            if path[2] == '|':
0217                path = "%s:%s" % (path[0:2], path[3:])
0218            # Remove leading slash
0219            if (path[0] == '/') and (path[2] == ':'):
0220                path = path[1:]
0221
0222        args = {}
0223        if query:
0224            for name, value in parse_qsl(query):
0225                args[name] = value
0226
0227        return user, password, host, port, path, args
0228
0229    def soClassAdded(self, soClass):
0230        """
0231        This is called for each new class; we use this opportunity
0232        to create an instance method that is bound to the class
0233        and this connection.
0234        """
0235        name = soClass.__name__
0236        assert not hasattr(self, name), (
0237            "Connection %r already has an attribute with the name "
0238            "%r (and you just created the conflicting class %r)"
0239            % (self, name, soClass))
0240        setattr(self, name, ConnWrapper(soClass, self))
0241
0242    def expireAll(self):
0243        """
0244        Expire all instances of objects for this connection.
0245        """
0246        cache_set = self.cache
0247        cache_set.weakrefAll()
0248        for item in cache_set.getAll():
0249            item.expire()
0250
class ConnWrapper(object):

    """
    A SQLObject class bound to one specific connection.

    Instances carry a connection, but classes are global; this proxy
    supplies the connection lazily: calling it instantiates the class with
    ``connection=`` set, and looking up a class method that accepts a
    ``connection`` argument returns it wrapped in a ConnMethodWrapper.
    """
    # @@: methods that take connection arguments should be explicitly
    # marked up instead of the implicit use of a connection argument
    # and inspect.getargspec()

    def __init__(self, soClass, connection):
        self._soClass = soClass
        self._connection = connection

    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        return self._soClass(*args, **kw)

    def __getattr__(self, attr):
        found = getattr(self._soClass, attr)
        if not isinstance(found, types.MethodType):
            # Plain attributes and non-method objects pass straight through.
            return found
        try:
            needsConnection = found.takes_connection
        except AttributeError:
            spec = inspect.getargspec(found)
            argNames, starArgs, starKw = spec[0], spec[1], spec[2]
            assert not (starKw or starArgs), (
                "I cannot tell whether I must wrap this method, "
                "because it takes **kw: %r"
                % found)
            needsConnection = 'connection' in argNames
            # Memoize on the underlying function to skip introspection next time.
            found.im_func.takes_connection = needsConnection
        if needsConnection:
            return ConnMethodWrapper(found, self._connection)
        return found
0289
class ConnMethodWrapper(object):

    """
    Callable proxy around a method that always passes a fixed
    ``connection`` keyword argument, overriding any the caller gave.
    """

    def __init__(self, method, connection):
        self._method = method
        self._connection = connection

    def __getattr__(self, attr):
        # Anything the proxy does not define is looked up on the method.
        return getattr(self._method, attr)

    def __call__(self, *args, **kw):
        kw['connection'] = self._connection
        target = self._method
        return target(*args, **kw)

    def __repr__(self):
        details = (self._method, self._connection)
        return '<Wrapped %r with connection %r>' % details
0306
class DBAPI(DBConnection):

    """
    Base class for connections built on a DB-API driver module.

    Subclass must define a `makeConnection()` method, which
    returns a newly-created connection object.

    ``_queryInsertID(conn, soInstance, id, names, values)`` must also be
    defined; ``queryInsertID`` runs it on a pooled connection.  Subclasses
    are further expected to provide ``module`` (the driver module; its
    ``Binary`` and ``Error`` are used here), ``supportTransactions``,
    ``createIDColumn``, ``createColumn``, ``createReferenceConstraint``,
    ``createIndexSQL`` and ``joinSQLType``.
    """

    dbName = None

    def __init__(self, **kw):
        # Pool state is created first; close() uses hasattr(self, '_pool')
        # to recognise a half-initialised instance.
        self._pool = []
        self._poolLock = threading.Lock()
        DBConnection.__init__(self, **kw)
        self._binaryType = type(self.module.Binary(''))

    def _runWithConnection(self, meth, *args):
        """Call ``meth(conn, *args)`` on a pooled connection, always releasing it."""
        conn = self.getConnection()
        try:
            val = meth(conn, *args)
        finally:
            self.releaseConnection(conn)
        return val

    def getConnection(self):
        """
        Take a raw connection from the pool, creating (and numbering) a
        new one with makeConnection() when the pool is empty.
        """
        self._poolLock.acquire()
        try:
            if not self._pool:
                conn = self.makeConnection()
                self._connectionNumbers[id(conn)] = self._connectionCount
                self._connectionCount += 1
            else:
                conn = self._pool.pop()
            if self.debug:
                s = 'ACQUIRE'
                if self._pool is not None:
                    s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
                self.printDebug(conn, s, 'Pool')
            return conn
        finally:
            self._poolLock.release()

    def releaseConnection(self, conn, explicit=False):
        """
        Hand *conn* back after use.

        Unless *explicit* is true, a backend with ``supportTransactions``
        first finishes the implicit transaction according to
        ``self.autoCommit``: 'exception' rolls back and raises, any other
        true value commits (skipped when the driver connection reports
        ``autocommit``), and a false value rolls back.  The connection is
        then returned to the pool, or closed when pooling is disabled
        (``self._pool is None``).
        """
        if self.debug:
            if explicit:
                s = 'RELEASE (explicit)'
            else:
                s = 'RELEASE (implicit, autocommit=%s)' % self.autoCommit
            if self._pool is None:
                s += ' no pooling'
            else:
                s += ' pool=[%s]' % ', '.join([str(self._connectionNumbers[id(v)]) for v in self._pool])
            self.printDebug(conn, s, 'Pool')
        if self.supportTransactions and not explicit:
            if self.autoCommit == 'exception':
                if self.debug:
                    self.printDebug(conn, 'auto/exception', 'ROLLBACK')
                conn.rollback()
                # Give the rolled-back connection back before raising;
                # previously it was neither pooled nor closed (leaked).
                self._returnConnection(conn)
                raise Exception('Object used outside of a transaction; implicit COMMIT or ROLLBACK not allowed')
            elif self.autoCommit:
                if self.debug:
                    self.printDebug(conn, 'auto', 'COMMIT')
                if not getattr(conn, 'autocommit', False):
                    conn.commit()
            else:
                if self.debug:
                    self.printDebug(conn, 'auto', 'ROLLBACK')
                conn.rollback()
        self._returnConnection(conn)

    def _returnConnection(self, conn):
        """Put *conn* back into the pool, or close it if pooling is disabled."""
        if self._pool is not None:
            if conn not in self._pool:
                # @@: We can get duplicate releasing of connections with
                # the __del__ in Iteration (unfortunately, not sure why
                # it happens)
                self._pool.insert(0, conn)
        else:
            conn.close()

    def printDebug(self, conn, s, name, type='query'):
        """
        Write one debug line about *conn* through ``self.debugWriter``.

        Lines named 'Pool' are emitted only when ``self.debug == 'Pool'``.
        For ``type != 'query'`` (results) *s* is repr()'d and '->' is used
        as the separator.
        """
        if name == 'Pool' and self.debug != 'Pool':
            return
        if type == 'query':
            sep = ': '
        else:
            sep = '->'
            s = repr(s)
        n = self._connectionNumbers[id(conn)]
        spaces = ' '*(8-len(name))
        if self.debugThreading:
            threadName = threading.currentThread().getName()
            threadName = (':' + threadName + ' '*(8-len(threadName)))
        else:
            threadName = ''
        msg = '%(n)2i%(threadName)s/%(name)s%(spaces)s%(sep)s %(s)s' % locals()
        self.debugWriter.write(msg)

    def _executeRetry(self, conn, cursor, query):
        """Execute *query* on *cursor*; this base version tries exactly once."""
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        return cursor.execute(query)

    def _query(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'Query')
        self._executeRetry(conn, conn.cursor(), s)

    def query(self, s):
        """Execute statement *s* on a pooled connection; returns None."""
        return self._runWithConnection(self._query, s)

    def _queryAll(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'QueryAll')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryAll', 'result')
        return value

    def queryAll(self, s):
        """Execute *s* and return all rows (cursor.fetchall())."""
        return self._runWithConnection(self._queryAll, s)

    def _queryAllDescription(self, conn, s):
        """
        Like queryAll, but returns (description, rows), where the
        description is cursor.description (which gives row types)
        """
        if self.debug:
            self.printDebug(conn, s, 'QueryAllDesc')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchall()
        if self.debugOutput:
            # Label the result with the same name as the query line.
            self.printDebug(conn, value, 'QueryAllDesc', 'result')
        return c.description, value

    def queryAllDescription(self, s):
        return self._runWithConnection(self._queryAllDescription, s)

    def _queryOne(self, conn, s):
        if self.debug:
            self.printDebug(conn, s, 'QueryOne')
        c = conn.cursor()
        self._executeRetry(conn, c, s)
        value = c.fetchone()
        if self.debugOutput:
            self.printDebug(conn, value, 'QueryOne', 'result')
        return value

    def queryOne(self, s):
        """Execute *s* and return the first row (cursor.fetchone())."""
        return self._runWithConnection(self._queryOne, s)

    def _insertSQL(self, table, names, values):
        return ("INSERT INTO %s (%s) VALUES (%s)" %
                (table, ', '.join(names),
                 ', '.join([self.sqlrepr(v) for v in values])))

    def transaction(self):
        """Start and return a new Transaction on this connection."""
        return Transaction(self)

    def queryInsertID(self, soInstance, id, names, values):
        return self._runWithConnection(self._queryInsertID, soInstance, id, names, values)

    def iterSelect(self, select):
        """Return an iterator over *select*; it owns and later releases the raw connection."""
        return select.IterationClass(self, self.getConnection(),
                         select, keepConnection=False)

    def accumulateSelect(self, select, *expressions):
        """ Apply an accumulate function(s) (SUM, COUNT, MIN, AVG, MAX, etc...)
            to the select object.
        """
        q = select.queryForSelect().newItems(expressions).unlimited().orderBy(None)
        q = self.sqlrepr(q)
        val = self.queryOne(q)
        if len(expressions) == 1:
            val = val[0]
        return val

    def queryForSelect(self, select):
        return self.sqlrepr(select.queryForSelect())

    def _SO_createJoinTable(self, join):
        self.query(self._SO_createJoinTableSQL(join))

    def _SO_createJoinTableSQL(self, join):
        return ('CREATE TABLE %s (\n%s %s,\n%s %s\n)' %
                (join.intermediateTable,
                 join.joinColumn,
                 self.joinSQLType(join),
                 join.otherColumn,
                 self.joinSQLType(join)))

    def _SO_dropJoinTable(self, join):
        self.query("DROP TABLE %s" % join.intermediateTable)

    def _SO_createIndex(self, soClass, index):
        self.query(self.createIndexSQL(soClass, index))

    def createIndexSQL(self, soClass, index):
        # Explicit raise: a bare ``assert 0`` vanishes under python -O and
        # the method would silently return None.
        raise NotImplementedError('Implement in subclasses')

    def createTable(self, soClass):
        """Create the table for *soClass*; return the constraints/extra SQL still to run."""
        createSql, constraints = self.createTableSQL(soClass)
        self.query(createSql)

        return constraints

    def createReferenceConstraints(self, soClass):
        """Return the non-empty foreign-key constraint definitions for *soClass*."""
        refConstraints = [self.createReferenceConstraint(soClass, column)
                          for column in soClass.sqlmeta.columnList
                          if isinstance(column, col.SOForeignKey)]
        return [constraint for constraint in refConstraints if constraint]

    def createSQL(self, soClass):
        """
        Return the extra SQL statements from ``soClass.sqlmeta.createSQL``
        as a list (possibly empty).

        ``createSQL`` may be a string, a list/tuple of strings, or a dict
        mapping a dbName to either of those; for a dict, the entry for this
        connection's ``dbName`` is used.
        """
        tableCreateSQLs = getattr(soClass.sqlmeta, 'createSQL', None)
        if tableCreateSQLs:
            assert isinstance(tableCreateSQLs, (basestring, list, dict, tuple)), (
                '%s.sqlmeta.createSQL must be a str, list, dict or tuple.' %
                (soClass.__name__))
            if isinstance(tableCreateSQLs, dict):
                # Use the connection generating the SQL; soClass._connection
                # may be unset or a different backend when an explicit
                # connection is passed.
                tableCreateSQLs = tableCreateSQLs.get(self.dbName, [])
            if isinstance(tableCreateSQLs, basestring):
                tableCreateSQLs = [tableCreateSQLs]
            if isinstance(tableCreateSQLs, tuple):
                tableCreateSQLs = list(tableCreateSQLs)
            assert isinstance(tableCreateSQLs, list), (
                'Unable to create a list from %s.sqlmeta.createSQL' %
                (soClass.__name__))
        return tableCreateSQLs or []

    def createTableSQL(self, soClass):
        """Return ``(CREATE TABLE statement, constraints + extra SQL)`` for *soClass*."""
        constraints = self.createReferenceConstraints(soClass)
        extraSQL = self.createSQL(soClass)
        createSql = ('CREATE TABLE %s (\n%s\n)' %
                (soClass.sqlmeta.table, self.createColumns(soClass)))
        return createSql, constraints + extraSQL

    def createColumns(self, soClass):
        """Return the column definitions (id column first), one per indented line."""
        # 'column' rather than 'col' so the col module is not shadowed.
        columnDefs = [self.createIDColumn(soClass)]
        columnDefs.extend([self.createColumn(soClass, column)
                           for column in soClass.sqlmeta.columnList])
        return ",\n".join(["    %s" % c for c in columnDefs])

    def createReferenceConstraint(self, soClass, col):
        raise NotImplementedError("Implement in subclasses")

    def createColumn(self, soClass, col):
        raise NotImplementedError("Implement in subclasses")

    def dropTable(self, tableName, cascade=False):
        self.query("DROP TABLE %s" % tableName)

    def clearTable(self, tableName):
        # 3-03 @@: Should this have a WHERE 1 = 1 or similar
        # clause?  In some configurations without the WHERE clause
        # the query won't go through, but maybe we shouldn't override
        # that.
        self.query("DELETE FROM %s" % tableName)

    def createBinary(self, value):
        """
        Create a binary object wrapper for the given database.
        """
        # Default is Binary() function from the connection driver.
        return self.module.Binary(value)

    # The _SO_* series of methods are sorts of "friend" methods
    # with SQLObject.  They grab values from the SQLObject instances
    # or classes freely, but keep the SQLObject class from accessing
    # the database directly.  This way no SQL is actually created
    # in the SQLObject class.

    def _SO_update(self, so, values):
        self.query("UPDATE %s SET %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    ", ".join(["%s = (%s)" % (dbName, self.sqlrepr(value))
                               for dbName, value in values]),
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectOne(self, so, columnNames):
        return self._SO_selectOneAlt(so, columnNames, so.q.id==so.id)


    def _SO_selectOneAlt(self, so, columnNames, condition):
        if columnNames:
            columns = [isinstance(x, basestring) and sqlbuilder.SQLConstant(x) or x for x in columnNames]
        else:
            columns = None
        return self.queryOne(self.sqlrepr(sqlbuilder.Select(columns,
                                                            staticTables=[so.sqlmeta.table],
                                                            clause=condition)))

    def _SO_delete(self, so):
        self.query("DELETE FROM %s WHERE %s = (%s)" %
                   (so.sqlmeta.table,
                    so.sqlmeta.idName,
                    self.sqlrepr(so.id)))

    def _SO_selectJoin(self, soClass, column, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (soClass.sqlmeta.idName,
                              soClass.sqlmeta.table,
                              column,
                              self.sqlrepr(value)))

    def _SO_intermediateJoin(self, table, getColumn, joinColumn, value):
        return self.queryAll("SELECT %s FROM %s WHERE %s = (%s)" %
                             (getColumn,
                              table,
                              joinColumn,
                              self.sqlrepr(value)))

    def _SO_intermediateDelete(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("DELETE FROM %s WHERE %s = (%s) AND %s = (%s)" %
                   (table,
                    firstColumn,
                    self.sqlrepr(firstValue),
                    secondColumn,
                    self.sqlrepr(secondValue)))

    def _SO_intermediateInsert(self, table, firstColumn, firstValue,
                               secondColumn, secondValue):
        self.query("INSERT INTO %s (%s, %s) VALUES (%s, %s)" %
                   (table,
                    firstColumn,
                    secondColumn,
                    self.sqlrepr(firstValue),
                    self.sqlrepr(secondValue)))

    def _SO_columnClause(self, soClass, kw):
        """
        Build a WHERE clause ANDing ``column = value`` (or ``IS NULL`` for
        None) for each keyword in *kw*; return None when *kw* selects
        nothing.  Consumes *kw*; raises TypeError for unknown keywords.
        """
        ops = {None: "IS"}
        data = {}
        if 'id' in kw:
            data[soClass.sqlmeta.idName] = kw.pop('id')
        for key, column in soClass.sqlmeta.columns.items():
            if key in kw:
                value = kw.pop(key)
                if column.from_python:
                    value = column.from_python(value, sqlbuilder.SQLObjectState(soClass, connection=self))
                data[column.dbName] = value
            elif column.foreignName in kw:
                obj = kw.pop(column.foreignName)
                if isinstance(obj, main.SQLObject):
                    data[column.dbName] = obj.id
                else:
                    data[column.dbName] = obj
        if kw:
            # Report every leftover keyword, not just the first one.
            raise TypeError("got an unexpected keyword argument(s): %r" % kw.keys())

        if not data:
            return None
        return ' AND '.join(
            ['%s %s %s' %
             (dbName, ops.get(value, "="), self.sqlrepr(value))
             for dbName, value
             in data.items()])

    def sqlrepr(self, v):
        return sqlrepr(v, self.dbName)

    def __del__(self):
        self.close()

    def close(self):
        """
        Close every pooled connection and empty the pool.

        ``self.module.Error`` raised while closing an individual
        connection is ignored (best effort).
        """
        if not hasattr(self, '_pool'):
            # Probably there was an exception while creating this
            # instance, so it is incomplete.
            return
        if not self._pool:
            return
        self._poolLock.acquire()
        try:
            if not self._pool: # another thread may have emptied _pool meanwhile
                return
            conns = self._pool[:]
            self._pool[:] = []
            for conn in conns:
                try:
                    conn.close()
                except self.module.Error:
                    pass
        finally:
            self._poolLock.release()

    def createEmptyDatabase(self):
        """
        Create an empty database.
        """
        raise NotImplementedError
0704
0705class Iteration(object):
0706
0707    def __init__(self, dbconn, rawconn, select, keepConnection=False):
0708        self.dbconn = dbconn
0709        self.rawconn = rawconn
0710        self.select = select
0711        self.keepConnection = keepConnection
0712        self.cursor = rawconn.cursor()
0713        self.query = self.dbconn.queryForSelect(select)
0714        if dbconn.debug:
0715            dbconn.printDebug(rawconn, self.query, 'Select')
0716        self.dbconn._executeRetry(self.rawconn, self.cursor, self.query)
0717
0718    def __iter__(self):
0719        return self
0720
0721    def next(self):
0722        result = self.cursor.fetchone()
0723        if result is None:
0724            self._cleanup()
0725            raise StopIteration
0726        if result[0] is None:
0727            return None
0728        if self.select.ops.get('lazyColumns', 0):
0729            obj = self.select.sourceClass.get(result[0], connection=self.dbconn)
0730            return obj
0731        else:
0732            obj = self.select.sourceClass.get(result[0], selectResults=result[1:], connection=self.dbconn)
0733            return obj
0734
0735    def _cleanup(self):
0736        if getattr(self, 'query', None) is None:
0737            # already cleaned up
0738            return
0739        self.query = None
0740        if not self.keepConnection:
0741            self.dbconn.releaseConnection(self.rawconn)
0742        self.dbconn = self.rawconn = self.select = self.cursor = None
0743
0744    def __del__(self):
0745        self._cleanup()
0746
0747class Transaction(object):
0748
    def __init__(self, dbConnection):
        """Begin a transaction on a connection checked out of *dbConnection*.

        Takes a raw connection with ``dbConnection.getConnection()``, turns
        its autocommit off via ``dbConnection._setAutoCommit(conn, 0)`` and
        creates a private CacheSet (honouring ``dbConnection.doCache``) for
        objects loaded inside the transaction.
        """
        # this is to skip __del__ in case of an exception in this __init__
        # (it also makes __getattr__/assertActive reject use of a
        # half-built transaction)
        self._obsolete = True
        self._dbConnection = dbConnection
        self._connection = dbConnection.getConnection()
        self._dbConnection._setAutoCommit(self._connection, 0)
        self.cache = CacheSet(cache=dbConnection.doCache)
        # class name -> list of ids deleted via _SO_delete(); commit()
        # expires these in the parent connection's cache.
        self._deletedCache = {}
        self._obsolete = False
0758
0759    def assertActive(self):
0760        assert not self._obsolete, "This transaction has already gone through ROLLBACK; begin another transaction"
0761
0762    def query(self, s):
0763        self.assertActive()
0764        return self._dbConnection._query(self._connection, s)
0765
0766    def queryAll(self, s):
0767        self.assertActive()
0768        return self._dbConnection._queryAll(self._connection, s)
0769
0770    def queryOne(self, s):
0771        self.assertActive()
0772        return self._dbConnection._queryOne(self._connection, s)
0773
0774    def queryInsertID(self, soInstance, id, names, values):
0775        self.assertActive()
0776        return self._dbConnection._queryInsertID(
0777            self._connection, soInstance, id, names, values)
0778
0779    def iterSelect(self, select):
0780        self.assertActive()
0781        # We can't keep the cursor open with results in a transaction,
0782        # because we might want to use the connection while we're
0783        # still iterating through the results.
0784        # @@: But would it be okay for psycopg, with threadsafety
0785        # level 2?
0786        return iter(list(select.IterationClass(self, self._connection,
0787                                   select, keepConnection=True)))
0788
0789    def _SO_delete(self, inst):
0790        cls = inst.__class__.__name__
0791        if not cls in self._deletedCache:
0792            self._deletedCache[cls] = []
0793        self._deletedCache[cls].append(inst.id)
0794        meth = new.instancemethod(self._dbConnection._SO_delete.im_func, self, self.__class__)
0795        return meth(inst)
0796
0797    def commit(self, close=False):
0798        if self._obsolete:
0799            # @@: is it okay to get extraneous commits?
0800            return
0801        if self._dbConnection.debug:
0802            self._dbConnection.printDebug(self._connection, '', 'COMMIT')
0803        self._connection.commit()
0804        subCaches = [(sub[0], sub[1].allIDs()) for sub in self.cache.allSubCachesByClassNames().items()]
0805        subCaches.extend([(x[0], x[1]) for x in self._deletedCache.items()])
0806        for cls, ids in subCaches:
0807            for id in ids:
0808                inst = self._dbConnection.cache.tryGetByName(id, cls)
0809                if inst is not None:
0810                    inst.expire()
0811        if close:
0812            self._makeObsolete()
0813
0814    def rollback(self):
0815        if self._obsolete:
0816            # @@: is it okay to get extraneous rollbacks?
0817            return
0818        if self._dbConnection.debug:
0819            self._dbConnection.printDebug(self._connection, '', 'ROLLBACK')
0820        subCaches = [(sub, sub.allIDs()) for sub in self.cache.allSubCaches()]
0821        self._connection.rollback()
0822
0823        for subCache, ids in subCaches:
0824            for id in ids:
0825                inst = subCache.tryGet(id)
0826                if inst is not None:
0827                    inst.expire()
0828        self._makeObsolete()
0829
0830    def __getattr__(self, attr):
0831        """
0832        If nothing else works, let the parent connection handle it.
0833        Except with this transaction as 'self'.  Poor man's
0834        acquisition?  Bad programming?  Okay, maybe.
0835        """
0836        self.assertActive()
0837        attr = getattr(self._dbConnection, attr)
0838        try:
0839            func = attr.im_func
0840        except AttributeError:
0841            if isinstance(attr, ConnWrapper):
0842                return ConnWrapper(attr._soClass, self)
0843            else:
0844                return attr
0845        else:
0846            meth = new.instancemethod(func, self, self.__class__)
0847            return meth
0848
0849    def _makeObsolete(self):
0850        self._obsolete = True
0851        if self._dbConnection.autoCommit:
0852            self._dbConnection._setAutoCommit(self._connection, 1)
0853        self._dbConnection.releaseConnection(self._connection,
0854                                             explicit=True)
0855        self._connection = None
0856        self._deletedCache = {}
0857
0858    def begin(self):
0859        # @@: Should we do this, or should begin() be a no-op when we're
0860        # not already obsolete?
0861        assert self._obsolete, "You cannot begin a new transaction session without rolling back this one"
0862        self._obsolete = False
0863        self._connection = self._dbConnection.getConnection()
0864        self._dbConnection._setAutoCommit(self._connection, 0)
0865
0866    def __del__(self):
0867        if self._obsolete:
0868            return
0869        self.rollback()
0870
0871    def close(self):
0872        raise TypeError('You cannot just close transaction - you should either call rollback(), commit() or commit(close=True) to close the underlying connection.')
0873
class ConnectionHub(object):

    """
    This object serves as a hub for connections, so that you can pass
    in a ConnectionHub to a SQLObject subclass as though it was a
    connection, but actually bind a real database connection later.
    You can also bind connections on a per-thread basis.

    You must hang onto the original ConnectionHub instance, as you
    cannot retrieve it again from the class or instance.

    To use the hub, do something like::

        hub = ConnectionHub()
        class MyClass(SQLObject):
            _connection = hub

        hub.threadConnection = connectionFromURI('...')

    """

    def __init__(self):
        # Per-thread storage backing the ``threadConnection`` property.
        self.threadingLocal = threading_local()

    def __get__(self, obj, type=None):
        # I'm a little surprised we have to do this, but apparently
        # the object's private dictionary of attributes doesn't
        # override this descriptor.  (Defining __set__ makes this a data
        # descriptor, and data descriptors take precedence over the
        # instance __dict__.)
        if (obj is not None) and '_connection' in obj.__dict__:
            return obj.__dict__['_connection']
        return self.getConnection()

    def __set__(self, obj, value):
        # Store a per-instance override that __get__ hands back.
        obj.__dict__['_connection'] = value

    def _currentBinding(self):
        """
        Return ``(connection, isThreadLocal)`` for the bound connection.

        The thread-local connection wins over ``processConnection``.
        Raises AttributeError if neither has been bound.
        """
        try:
            return self.threadingLocal.connection, True
        except AttributeError:
            pass
        try:
            return self.processConnection, False
        except AttributeError:
            raise AttributeError(
                "No connection has been defined for this thread "
                "or process")

    def _rebind(self, conn, isThreadLocal):
        """Bind *conn* in the thread-local slot or the process-wide slot."""
        if isThreadLocal:
            self.threadConnection = conn
        else:
            self.processConnection = conn

    def getConnection(self):
        """
        Return the thread-local connection if one is bound, otherwise
        ``processConnection``.  Raises AttributeError if neither is set.
        """
        return self._currentBinding()[0]

    def doInTransaction(self, func, *args, **kw):
        """
        This routine can be used to run a function in a transaction,
        rolling the transaction back if any exception is raised from
        that function, and committing otherwise.

        Use like::

            sqlhub.doInTransaction(process_request, os.environ)

        This will run ``process_request(os.environ)``.  The return
        value will be preserved.

        While *func* runs, the transaction replaces the current connection
        in the same slot (thread-local or process-wide).  The original
        connection is restored afterwards.  Raises AttributeError, with the
        same message as getConnection(), if no connection is bound.
        """
        # @@: In Python 2.5, something usable with with: should also
        # be added.
        old_conn, old_conn_is_threading = self._currentBinding()
        conn = old_conn.transaction()
        self._rebind(conn, old_conn_is_threading)
        try:
            try:
                value = func(*args, **kw)
            except:
                # Deliberately catch everything (KeyboardInterrupt too):
                # roll back, then re-raise unchanged.
                conn.rollback()
                raise
            else:
                conn.commit(close=True)
                return value
        finally:
            self._rebind(old_conn, old_conn_is_threading)

    def _set_threadConnection(self, value):
        self.threadingLocal.connection = value

    def _get_threadConnection(self):
        return self.threadingLocal.connection

    def _del_threadConnection(self):
        del self.threadingLocal.connection

    threadConnection = property(_get_threadConnection,
                                _set_threadConnection,
                                _del_threadConnection)
0973
class ConnectionURIOpener(object):

    """
    Turn connection URIs (or registered names) into connection objects.

    * ``schemeBuilders`` maps a URI scheme to a zero-argument callable that
      returns the connection class for that scheme.
    * ``instanceNames`` maps a name to an already-built connection instance.
    * ``cachedURIs`` remembers the connection returned for each URI string.
    """

    def __init__(self):
        self.schemeBuilders = {}
        self.instanceNames = {}
        self.cachedURIs = {}

    def registerConnection(self, schemes, builder):
        """
        Register *builder* for every URI scheme in *schemes*.

        Registering the same builder again is allowed.  Registering a
        different builder for an already-known scheme fails an assertion.
        """
        for uriScheme in schemes:
            assert (uriScheme not in self.schemeBuilders
                    or self.schemeBuilders[uriScheme] is builder), \
                "A driver has already been registered for the URI scheme %s" % uriScheme
            self.schemeBuilders[uriScheme] = builder

    def registerConnectionInstance(self, inst):
        """
        Register *inst* under ``inst.name`` so that connectionForURI(name)
        returns it.  Instances with an empty name are ignored.

        Re-registering the same instance is allowed.  A different instance
        under an existing name, or a name containing ':', fails an
        assertion.
        """
        if inst.name:
            # Previously compared against / formatted an undefined name
            # ``cls``, which raised NameError instead of these checks.
            assert (inst.name not in self.instanceNames
                    or self.instanceNames[inst.name] is inst), \
                "A instance has already been registered with the name %s" % inst.name
            assert ':' not in inst.name, \
                "You cannot include ':' in your class names (%r)" % inst.name
            self.instanceNames[inst.name] = inst

    def connectionForURI(self, uri, oldUri=False, **args):
        """
        Return the connection for *uri*, building and caching it on first use.

        Extra keyword arguments are appended to the URI as query parameters,
        sorted by name so that the same arguments always produce the same
        cache key regardless of dict ordering.  A string containing ':' is
        treated as a URI and passed to the connection class registered for
        its scheme (``connectionFromOldURI`` when *oldUri* is true).
        Anything else must be a name given to registerConnectionInstance().
        """
        if args:
            params = list(args.items())
            params.sort()
            if '?' not in uri:
                uri += '?' + urllib.urlencode(params)
            else:
                uri += '&' + urllib.urlencode(params)
        if uri in self.cachedURIs:
            return self.cachedURIs[uri]
        if ':' in uri:
            scheme = uri.split(':', 1)[0]
            connCls = self.dbConnectionForScheme(scheme)
            if oldUri:
                conn = connCls.connectionFromOldURI(uri)
            else:
                conn = connCls.connectionFromURI(uri)
        else:
            # We just have a name, not a URI
            assert uri in self.instanceNames, \
                "No SQLObject driver exists under the name %s" % uri
            conn = self.instanceNames[uri]
        # @@: Do we care if we clobber another connection?
        self.cachedURIs[uri] = conn
        return conn

    def dbConnectionForScheme(self, scheme):
        """Return the connection class produced by *scheme*'s registered builder."""
        assert scheme in self.schemeBuilders, (
               "No SQLObject driver exists for %s (only %s)"
               % (scheme, ', '.join(self.schemeBuilders.keys())))
        return self.schemeBuilders[scheme]()
1025
# Module-wide opener, plus module-level shortcuts to its bound methods.
TheURIOpener = ConnectionURIOpener()

(registerConnection, registerConnectionInstance,
 connectionForURI, dbConnectionForScheme) = (
    TheURIOpener.registerConnection,
    TheURIOpener.registerConnectionInstance,
    TheURIOpener.connectionForURI,
    TheURIOpener.dbConnectionForScheme,
)
1032
1033# Register DB URI schemas
1034import firebird
1035import maxdb
1036import mssql
1037import mysql
1038import postgres
1039import rdbhost
1040import sqlite
1041import sybase