0001from sqlobject.dbconnection import DBAPI
0002import re
0003from sqlobject import col
0004from sqlobject import sqlbuilder
0005from sqlobject.converters import registerConverter
0006psycopg = None
0007pgdb = None
0008
0009class PostgresConnection(DBAPI):
0010
0011    supportTransactions = True
0012    dbName = 'postgres'
0013    schemes = [dbName, 'postgresql', 'psycopg']
0014
0015    def __init__(self, dsn=None, host=None, port=None, db=None,
0016                 user=None, passwd=None, usePygresql=False, unicodeCols=False,
0017                 **kw):
0018        global psycopg, pgdb
0019        self.usePygresql = usePygresql
0020        if usePygresql:
0021            if pgdb is None:
0022                import pgdb
0023            self.module = pgdb
0024        else:
0025            if psycopg is None:
0026                import psycopg
0027            self.module = psycopg
0028
0029            # Register a converter for psycopg Binary type.
0030            registerConverter(type(psycopg.Binary('')),
0031                              PsycoBinaryConverter)
0032
0033        self.user = user
0034        self.host = host
0035        self.port = port
0036        self.db = db
0037        self.password = passwd
0038        self.dsn_dict = dsn_dict = {}
0039        if host:
0040            dsn_dict["host"] = host
0041        if port:
0042            if usePygresql:
0043                dsn_dict["host"] = "%s:%d" % (host, port)
0044            else:
0045                dsn_dict["port"] = str(port)
0046        if db:
0047            dsn_dict["database"] = db
0048        if user:
0049            dsn_dict["user"] = user
0050        if passwd:
0051            dsn_dict["password"] = passwd
0052        self.use_dsn = dsn is not None
0053        if dsn is None:
0054            if usePygresql:
0055                dsn = ''
0056                if host:
0057                    dsn += host
0058                dsn += ':'
0059                if db:
0060                    dsn += db
0061                dsn += ':'
0062                if user:
0063                    dsn += user
0064                dsn += ':'
0065                if passwd:
0066                    dsn += passwd
0067            else:
0068                dsn = []
0069                if db:
0070                    dsn.append('dbname=%s' % db)
0071                if user:
0072                    dsn.append('user=%s' % user)
0073                if passwd:
0074                    dsn.append('password=%s' % passwd)
0075                if host:
0076                    dsn.append('host=%s' % host)
0077                if port:
0078                    dsn.append('port=%d' % port)
0079                dsn = ' '.join(dsn)
0080        self.dsn = dsn
0081        self.unicodeCols = unicodeCols
0082        DBAPI.__init__(self, **kw)
0083
0084        # Server version cache
0085        self._server_version = None # Not yet initialized
0086
0087    def connectionFromURI(cls, uri):
0088        user, password, host, port, path, args = cls._parseURI(uri)
0089        path = path.strip('/')
0090        return cls(host=host, port=port, db=path, user=user, passwd=password, **args)
0091    connectionFromURI = classmethod(connectionFromURI)
0092
0093    def _setAutoCommit(self, conn, auto):
0094        # psycopg2 does not have an autocommit method.
0095        if hasattr(conn, 'autocommit'):
0096            conn.autocommit(auto)
0097
0098    def makeConnection(self):
0099        try:
0100            if self.use_dsn:
0101                conn = self.module.connect(self.dsn)
0102            else:
0103                conn = self.module.connect(**self.dsn_dict)
0104        except self.module.OperationalError, e:
0105            raise self.module.OperationalError("%s; used connection string %r" % (e, self.dsn))
0106        if self.autoCommit:
0107            # psycopg2 does not have an autocommit method.
0108            if hasattr(conn, 'autocommit'):
0109                conn.autocommit(1)
0110        return conn
0111
0112    def _queryInsertID(self, conn, soInstance, id, names, values):
0113        table = soInstance.sqlmeta.table
0114        idName = soInstance.sqlmeta.idName
0115        sequenceName = getattr(soInstance, '_idSequence',
0116                               '%s_%s_seq' % (table, idName))
0117        c = conn.cursor()
0118        if id is None:
0119            c.execute("SELECT NEXTVAL('%s')" % sequenceName)
0120            id = c.fetchone()[0]
0121        names = [idName] + names
0122        values = [id] + values
0123        q = self._insertSQL(table, names, values)
0124        if self.debug:
0125            self.printDebug(conn, q, 'QueryIns')
0126        c.execute(q)
0127        if self.debugOutput:
0128            self.printDebug(conn, id, 'QueryIns', 'result')
0129        return id
0130
0131    def _queryAddLimitOffset(self, query, start, end):
0132        if not start:
0133            return "%s LIMIT %i" % (query, end)
0134        if not end:
0135            return "%s OFFSET %i" % (query, start)
0136        return "%s LIMIT %i OFFSET %i" % (query, end-start, start)
0137
0138    def createColumn(self, soClass, col):
0139        return col.postgresCreateSQL()
0140
0141    def createIndexSQL(self, soClass, index):
0142        return index.postgresCreateIndexSQL(soClass)
0143
0144    def createIDColumn(self, soClass):
0145        return '%s SERIAL PRIMARY KEY' % soClass.sqlmeta.idName
0146
0147    def dropTable(self, tableName, cascade=False):
0148        if self.server_version[:3] <= "7.2":
0149            cascade=False
0150        self.query("DROP TABLE %s %s" % (tableName,
0151                                         cascade and 'CASCADE' or ''))
0152
0153    def joinSQLType(self, join):
0154        return 'INT NOT NULL'
0155
0156    def tableExists(self, tableName):
0157        result = self.queryOne("SELECT COUNT(relname) FROM pg_class WHERE relname = %s"
0158                               % self.sqlrepr(tableName))
0159        return result[0]
0160
0161    def addColumn(self, tableName, column):
0162        self.query('ALTER TABLE %s ADD COLUMN %s' %
0163                   (tableName,
0164                    column.postgresCreateSQL()))
0165
0166    def delColumn(self, tableName, column):
0167        self.query('ALTER TABLE %s DROP COLUMN %s' %
0168                   (tableName,
0169                    column.dbName))
0170
0171    def columnsFromSchema(self, tableName, soClass):
0172
0173        keyQuery = """
0174        SELECT pg_catalog.pg_get_constraintdef(oid) as condef
0175        FROM pg_catalog.pg_constraint r
0176        WHERE r.conrelid = %s::regclass AND r.contype = 'f'"""
0177
0178        colQuery = """
0179        SELECT a.attname,
0180        pg_catalog.format_type(a.atttypid, a.atttypmod), a.attnotnull,
0181        (SELECT substring(d.adsrc for 128) FROM pg_catalog.pg_attrdef d
0182        WHERE d.adrelid=a.attrelid AND d.adnum = a.attnum)
0183        FROM pg_catalog.pg_attribute a
0184        WHERE a.attrelid =%s::regclass
0185        AND a.attnum > 0 AND NOT a.attisdropped
0186        ORDER BY a.attnum"""
0187
0188        primaryKeyQuery = """
0189        SELECT pg_index.indisprimary,
0190            pg_catalog.pg_get_indexdef(pg_index.indexrelid)
0191        FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
0192            pg_catalog.pg_index AS pg_index
0193        WHERE c.relname = %s
0194            AND c.oid = pg_index.indrelid
0195            AND pg_index.indexrelid = c2.oid
0196            AND pg_index.indisprimary
0197        """
0198
0199        keyData = self.queryAll(keyQuery % self.sqlrepr(tableName))
0200        keyRE = re.compile(r"\((.+)\) REFERENCES (.+)\(")
0201        keymap = {}
0202
0203        for (condef,) in keyData:
0204            match = keyRE.search(condef)
0205            if match:
0206                field, reftable = match.groups()
0207                keymap[field] = reftable.capitalize()
0208
0209        primaryData = self.queryAll(primaryKeyQuery % self.sqlrepr(tableName))
0210        primaryRE = re.compile(r'CREATE .*? USING .* \((.+?)\)')
0211        primaryKey = None
0212        for isPrimary, indexDef in primaryData:
0213            match = primaryRE.search(indexDef)
0214            assert match, "Unparseable contraint definition: %r" % indexDef
0215            assert primaryKey is None, "Already found primary key (%r), then found: %r" % (primaryKey, indexDef)
0216            primaryKey = match.group(1)
0217        assert primaryKey, "No primary key found in table %r" % tableName
0218        if primaryKey.startswith('"'):
0219            assert primaryKey.endswith('"')
0220            primaryKey = primaryKey[1:-1]
0221
0222        colData = self.queryAll(colQuery % self.sqlrepr(tableName))
0223        results = []
0224        if self.unicodeCols:
0225            client_encoding = self.queryOne("SHOW client_encoding")[0]
0226        for field, t, notnull, defaultstr in colData:
0227            if field == primaryKey:
0228                continue
0229            colClass, kw = self.guessClass(t)
0230            if self.unicodeCols and colClass == col.StringCol:
0231                colClass = col.UnicodeCol
0232                kw['dbEncoding'] = client_encoding
0233            kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
0234            kw['dbName'] = field
0235            kw['notNone'] = notnull
0236            if defaultstr is not None:
0237                kw['default'] = self.defaultFromSchema(colClass, defaultstr)
0238            elif not notnull:
0239                kw['default'] = None
0240            if keymap.has_key(field):
0241                kw['foreignKey'] = keymap[field]
0242            results.append(colClass(**kw))
0243        return results
0244
0245    def guessClass(self, t):
0246        if t.count('int'):
0247            return col.IntCol, {}
0248        elif t.count('varying'):
0249            if '(' in t:
0250                return col.StringCol, {'length': int(t[t.index('(')+1:-1])}
0251            else: # varchar without length in Postgres means any length
0252                return col.StringCol, {}
0253        elif t.startswith('character('):
0254            return col.StringCol, {'length': int(t[t.index('(')+1:-1]),
0255                                   'varchar': False}
0256        elif t == 'text':
0257            return col.StringCol, {}
0258        elif t.startswith('datetime'):
0259            return col.DateTimeCol, {}
0260        elif t.startswith('bool'):
0261            return col.BoolCol, {}
0262        elif t.startswith('bytea'):
0263            return col.BLOBCol, {}
0264        else:
0265            return col.Col, {}
0266
0267    def defaultFromSchema(self, colClass, defaultstr):
0268        """
0269        If the default can be converted to a python constant, convert it.
0270        Otherwise return is as a sqlbuilder constant.
0271        """
0272        if colClass == col.BoolCol:
0273            if defaultstr == 'false':
0274                return False
0275            elif defaultstr == 'true':
0276                return True
0277        return getattr(sqlbuilder.const, defaultstr)
0278
0279    def server_version(self):
0280        if self._server_version is None:
0281            # The result is something like
0282            # ' PostgreSQL 7.2.1 on i686-pc-linux-gnu, compiled by GCC 2.95.4'
0283            server_version = self.queryOne("SELECT version()")[0]
0284            self._server_version = server_version.split()[1]
0285        return self._server_version
0286    server_version = property(server_version)
0287
0288    def createEmptyDatabase(self):
0289        # We have to connect to *some* database, so we'll connect to
0290        # template1, which is a common open database.
0291        # @@: This doesn't use self.use_dsn or self.dsn_dict
0292        if self.usePygresql:
0293            dsn = '%s:template1:%s:%s' % (
0294                self.host or '', self.user or '', self.password or '')
0295        else:
0296            dsn = 'dbname=template1'
0297            if self.user:
0298                dsn += ' user=%s' % self.user
0299            if self.password:
0300                dsn += ' password=%s' % self.password
0301            if self.host:
0302                dsn += ' host=%s' % self.host
0303        conn = self.module.connect(dsn)
0304        cur = conn.cursor()
0305        # We must close the transaction with a commit so that
0306        # the CREATE DATABASE can work (which can't be in a transaction):
0307        cur.execute('COMMIT')
0308        # And we can't use template1 since we're connected to template1,
0309        # so we use template0.  @@: What's the difference between
0310        # these two templates?
0311        cur.execute('CREATE DATABASE %s TEMPLATE=template0' % self.db)
0312        cur.close()
0313        conn.close()
0314
0315
0316
0317# Converter for psycopg Binary type.
0318def PsycoBinaryConverter(value, db):
0319    assert db == 'postgres'
0320    return str(value)