# Source code for sqlobject.postgres.pgconnection

from getpass import getuser
import re
from sqlobject import col
from sqlobject import dberrors
from sqlobject import sqlbuilder
from sqlobject.compat import PY2
from sqlobject.converters import registerConverter, sqlrepr
from sqlobject.dbconnection import DBAPI


[docs]class ErrorMessage(str): def __new__(cls, e, append_msg=''): eargs0 = emessage = e.args[0] if e.__module__.startswith('pg8000') \ and isinstance(e.args, tuple) and len(e.args) > 1: # pg8000 =~ 1.12 for Python 3.4 ecode = e.args[2] eerror = emessage = e.args[3] elif e.__module__.startswith('pg8000') and isinstance(eargs0, dict): # pg8000 =~ 1.13 for Python 2.7 # pg8000 for Python 3.5+ ecode = eargs0['C'] eerror = emessage = eargs0['M'] elif e.__module__ == 'pg': # PyGreSQL ecode = e.sqlstate eerror = emessage = e.args[0] elif hasattr(e, 'pgcode'): # psycopg2 or psycopg2.errors ecode = getattr(e, 'pgcode', None) eerror = getattr(e, 'pgerror', None) else: ecode = getattr(e, 'code', None) eerror = getattr(e, 'error', None) obj = str.__new__(cls, emessage + append_msg) obj.code = ecode obj.error = eerror obj.module = e.__module__ obj.exception = e.__class__.__name__ return obj
def _getuser(): # ``getuser()`` on w32 can raise ``ImportError`` # due to absent of ``pwd`` module. try: return getuser() except ImportError: return None
class PostgresConnection(DBAPI):
    """SQLObject connection to PostgreSQL through one of several drivers.

    The ``driver`` keyword is a comma-separated preference list (default
    ``'psycopg'``). Recognised names are psycopg, psycopg2, pygresql,
    py-postgresql / pypostgresql, pg8000, odbc, pyodbc and pypyodbc. The
    first one that imports successfully is used.
    """

    supportTransactions = True
    dbName = 'postgres'
    schemes = [dbName, 'postgresql']

    odbc_keywords = ('Server', 'Port', 'UID', 'Password', 'Database')

    def __init__(self, dsn=None, host=None, port=None, db=None,
                 user=None, password=None, **kw):
        """Select a driver and prepare the connection parameters.

        :param dsn: explicit connection string; when given it is passed
            unchanged to the driver's ``connect()``.
        :param host: server host, or a path starting with ``/`` for a
            Unix-domain socket (py-postgresql and pg8000).
        :param port: server port.
        :param db: database name.
        :param user: user name (pg8000 falls back to the login name).
        :param password: password.
        :param kw: ``driver``, ``sslmode``, ``odbcdrv``, ``unicodeCols``,
            ``schema`` and ``charset`` are consumed here. Everything else
            is forwarded to ``DBAPI.__init__``.
        :raises ValueError: for an unknown driver name.
        :raises ImportError: when none of the requested drivers imports.
        """
        drivers = kw.pop('driver', None) or 'psycopg'
        for driver in drivers.split(','):
            driver = driver.strip()
            if not driver:
                continue
            try:
                if driver == 'psycopg':
                    import psycopg
                    self.module = psycopg
                elif driver == 'psycopg2':
                    import psycopg2
                    self.module = psycopg2
                elif driver == 'pygresql':
                    import pgdb
                    self.module = pgdb
                elif driver in ('py-postgresql', 'pypostgresql'):
                    from postgresql.driver import dbapi20
                    self.module = dbapi20
                elif driver == 'pg8000':
                    import pg8000
                    self.module = pg8000
                elif driver == 'pyodbc':
                    import pyodbc
                    self.module = pyodbc
                elif driver == 'pypyodbc':
                    import pypyodbc
                    self.module = pypyodbc
                elif driver == 'odbc':
                    try:
                        import pyodbc
                    except ImportError:
                        import pypyodbc as pyodbc
                    self.module = pyodbc
                else:
                    raise ValueError(
                        'Unknown PostgreSQL driver "%s", '
                        'expected psycopg, psycopg2, '
                        'pygresql, pypostgresql, pg8000, '
                        'odbc, pyodbc or pypyodbc' % driver)
            except ImportError:
                # Driver not installed: try the next one in the list.
                pass
            else:
                break
        else:
            raise ImportError(
                'Cannot find a PostgreSQL driver, tried %s' % drivers)

        is_odbc = driver in ('odbc', 'pyodbc', 'pypyodbc')
        if driver.startswith('psycopg'):
            # Register a converter for psycopg Binary type.
            registerConverter(type(self.module.Binary('')),
                              PsycoBinaryConverter)
        elif driver in ('pygresql', 'py-postgresql', 'pypostgresql',
                        'pg8000'):
            registerConverter(type(self.module.Binary(b'')),
                              PostgresBinaryConverter)
        elif is_odbc:
            registerConverter(bytearray, OdbcBinaryConverter)

        self.db = db
        self.user = user
        self.password = password
        self.host = host
        self.port = port
        if is_odbc:
            self.make_odbc_conn_str(kw.pop('odbcdrv', 'PostgreSQL ANSI'),
                                    db, host, port, user, password)
            sslmode = kw.pop("sslmode", None)
            if sslmode:
                # Forward the requested mode. Hard-coding 'require' would
                # silently override e.g. 'disable' or 'verify-full'.
                self.odbc_conn_str += ';sslmode=%s' % sslmode
        else:
            self.dsn_dict = dsn_dict = {}
            if host:
                dsn_dict["host"] = host
            if port:
                # Read the version from self.module: a plain ``psycopg``
                # name is only bound when that exact driver was imported,
                # so psycopg2 used to fail here with UnboundLocalError.
                major = str(getattr(self.module, '__version__', '')
                            ).split('.')[0]
                if driver == 'pygresql':
                    # PyGreSQL takes the port as a "host:port" suffix; use
                    # an empty host part rather than the string "None".
                    dsn_dict["host"] = "%s:%d" % (host or '', port)
                elif driver.startswith('psycopg') and major == '1':
                    # psycopg major version 1 gets the port as a string.
                    dsn_dict["port"] = str(port)
                else:
                    dsn_dict["port"] = port
            if db:
                if driver == 'psycopg':
                    dsn_dict["dbname"] = db
                else:
                    dsn_dict["database"] = db
            if user:
                dsn_dict["user"] = user
            if password:
                dsn_dict["password"] = password
            sslmode = kw.pop("sslmode", None)
            if sslmode:
                dsn_dict["sslmode"] = sslmode
            # With an explicit dsn, makeConnection() passes it to the
            # driver. Otherwise it connects with dsn_dict, and the string
            # built below is what its error message reports.
            # NOTE(review): that string includes the password.
            self.use_dsn = dsn is not None
            if dsn is None:
                if driver == 'pygresql':
                    # "host:database:user:password"
                    dsn = ':'.join([host or '', db or '', user or '',
                                    password or ''])
                else:
                    parts = []
                    if db:
                        parts.append('dbname=%s' % db)
                    if user:
                        parts.append('user=%s' % user)
                    if password:
                        parts.append('password=%s' % password)
                    if host:
                        parts.append('host=%s' % host)
                    if port:
                        parts.append('port=%d' % port)
                    if sslmode:
                        parts.append('sslmode=%s' % sslmode)
                    dsn = ' '.join(parts)
            if driver in ('py-postgresql', 'pypostgresql'):
                if host and host.startswith('/'):
                    # A Unix-domain socket path replaces host and port.
                    dsn_dict["host"] = dsn_dict["port"] = None
                    dsn_dict["unix"] = host
            if driver == 'pg8000':
                if host and host.startswith('/'):
                    dsn_dict["host"] = None
                    dsn_dict["unix_sock"] = host
                if user is None:
                    dsn_dict["user"] = _getuser()
        self.dsn = dsn
        self.driver = driver
        self.unicodeCols = kw.pop('unicodeCols', False)
        self.schema = kw.pop('schema', None)
        self.dbEncoding = kw.pop("charset", None)
        DBAPI.__init__(self, **kw)

    @classmethod
    def _connectionFromParams(cls, user, password, host, port, path, args):
        """Build a connection from parsed connection-URI parts.

        A *path* with several components and no host denotes a
        non-default Unix socket directory. Its last component is the
        database name.
        """
        path = path.strip('/')
        if (host is None) and path.count('/'):  # Non-default unix socket
            path_parts = path.split('/')
            host = '/' + '/'.join(path_parts[:-1])
            path = path_parts[-1]
        return cls(host=host, port=port, db=path, user=user,
                   password=password, **args)

    def _setAutoCommit(self, conn, auto):
        """Switch *conn* into or out of autocommit mode when supported.

        Some drivers expose ``autocommit`` as a method and others
        (psycopg2, for one) as a plain attribute. Calling an attribute
        raises ``TypeError``, in which case it is assigned instead.
        Connections without ``autocommit`` are left untouched.
        """
        if hasattr(conn, 'autocommit'):
            try:
                conn.autocommit(auto)
            except TypeError:
                conn.autocommit = auto

    def makeConnection(self):
        """Open a new driver connection and apply per-connection settings.

        Enables autocommit if ``self.autoCommit`` is set. Applies the
        configured ``schema`` (search_path) and ``charset``
        (client_encoding). A driver OperationalError while connecting is
        re-raised as ``dberrors.OperationalError``.
        """
        try:
            if self.driver in ('odbc', 'pyodbc', 'pypyodbc'):
                if self.debug:
                    # The ODBC string contains the password; only echo it
                    # when debugging was explicitly requested.
                    self.debugWriter.write(
                        "ODBC connect string: " + self.odbc_conn_str)
                conn = self.module.connect(self.odbc_conn_str)
            elif self.use_dsn:
                conn = self.module.connect(self.dsn)
            else:
                conn = self.module.connect(**self.dsn_dict)
        except self.module.OperationalError as e:
            raise dberrors.OperationalError(
                ErrorMessage(e, "used connection string %r" % self.dsn))

        # For printDebug in _executeRetry
        self._connectionNumbers[id(conn)] = self._connectionCount

        initialised = False
        try:
            if self.autoCommit:
                self._setAutoCommit(conn, 1)
            c = conn.cursor()
            try:
                if self.schema:
                    self._executeRetry(conn, c,
                                       "SET search_path TO " + self.schema)
                dbEncoding = self.dbEncoding
                if dbEncoding:
                    if self.driver in ('odbc', 'pyodbc'):
                        conn.setdecoding(self.module.SQL_CHAR,
                                         encoding=dbEncoding)
                        conn.setdecoding(self.module.SQL_WCHAR,
                                         encoding=dbEncoding)
                        if PY2:
                            conn.setencoding(str, encoding=dbEncoding)
                            conn.setencoding(unicode,  # noqa
                                             encoding=dbEncoding)
                        else:
                            conn.setencoding(encoding=dbEncoding)
                    self._executeRetry(
                        conn, c, "SET client_encoding TO '%s'" % dbEncoding)
            finally:
                c.close()
            initialised = True
        finally:
            if not initialised:
                # Setup failed: don't leak a half-initialised connection.
                conn.close()
        return conn

    def _executeRetry(self, conn, cursor, query):
        """Execute *query* on *cursor*, mapping driver errors to dberrors.

        Unique-violation errors (code 23505, in whichever form the
        driver reports it) become ``dberrors.DuplicateEntryError``.
        """
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        dbEncoding = self.dbEncoding
        if dbEncoding and isinstance(query, bytes) and (
                self.driver == 'pg8000'):
            query = query.decode(dbEncoding)
        try:
            return cursor.execute(query)
        except self.module.OperationalError as e:
            raise dberrors.OperationalError(ErrorMessage(e))
        except self.module.IntegrityError as e:
            msg = ErrorMessage(e)
            # Guard e.args: an argument-less exception used to raise
            # IndexError here, masking the real integrity error.
            first_arg = e.args[0] if e.args else None
            if getattr(msg, 'code', -1) == '23505' or \
                    getattr(e, 'code', -1) == '23505' or \
                    getattr(e, 'pgcode', -1) == '23505' or \
                    getattr(e, 'sqlstate', -1) == '23505' or \
                    first_arg == '23505':
                raise dberrors.DuplicateEntryError(msg)
            else:
                raise dberrors.IntegrityError(msg)
        except self.module.InternalError as e:
            raise dberrors.InternalError(ErrorMessage(e))
        except self.module.ProgrammingError as e:
            msg = ErrorMessage(e)
            if (
                (len(e.args) > 2) and (e.args[1] == 'ERROR')
                and (e.args[2] == '23505')) \
                    or ((len(e.args) >= 2) and (e.args[1] == '23505')):
                raise dberrors.DuplicateEntryError(msg)
            else:
                raise dberrors.ProgrammingError(msg)
        except self.module.DataError as e:
            raise dberrors.DataError(ErrorMessage(e))
        except self.module.NotSupportedError as e:
            raise dberrors.NotSupportedError(ErrorMessage(e))
        except self.module.DatabaseError as e:
            msg = ErrorMessage(e)
            if 'duplicate key value violates unique constraint' in msg:
                raise dberrors.DuplicateEntryError(msg)
            else:
                raise dberrors.DatabaseError(msg)
        except self.module.InterfaceError as e:
            raise dberrors.InterfaceError(ErrorMessage(e))
        except self.module.Warning as e:
            raise Warning(ErrorMessage(e))
        except self.module.Error as e:
            raise dberrors.Error(ErrorMessage(e))

    def _queryInsertID(self, conn, soInstance, id, names, values):
        """Insert a row for *soInstance* and return its id.

        If *id* is None, the id is generated by the database. For
        py-postgresql it is fetched from the sequence first; otherwise
        the INSERT uses ``RETURNING <idName>``.
        """
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        c = conn.cursor()
        try:
            if id is None and self.driver in ('py-postgresql',
                                              'pypostgresql'):
                sequenceName = soInstance.sqlmeta.idSequence or \
                    '%s_%s_seq' % (table, idName)
                self._executeRetry(conn, c,
                                   "SELECT NEXTVAL('%s')" % sequenceName)
                id = c.fetchone()[0]
            if id is not None:
                names = [idName] + names
                values = [id] + values
            if names and values:
                q = self._insertSQL(table, names, values)
            else:
                q = "INSERT INTO %s DEFAULT VALUES" % table
            if id is None:
                q += " RETURNING " + idName
            if self.debug:
                self.printDebug(conn, q, 'QueryIns')
            self._executeRetry(conn, c, q)
            if id is None:
                id = c.fetchone()[0]
        finally:
            # Close the cursor even when the INSERT fails.
            c.close()
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id

    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        """Append LIMIT/OFFSET selecting rows ``start`` up to ``end``.

        A falsy *start* adds only LIMIT and a falsy *end* only OFFSET.
        """
        if not start:
            return "%s LIMIT %i" % (query, end)
        if not end:
            return "%s OFFSET %i" % (query, start)
        return "%s LIMIT %i OFFSET %i" % (query, end - start, start)

    def createColumn(self, soClass, col):
        """Return the PostgreSQL column definition for *col*."""
        return col.postgresCreateSQL()

    def createReferenceConstraint(self, soClass, col):
        """Return the PostgreSQL foreign-key constraint for *col*."""
        return col.postgresCreateReferenceConstraint()

    def createIndexSQL(self, soClass, index):
        """Return the PostgreSQL CREATE INDEX statement for *index*."""
        return index.postgresCreateIndexSQL(soClass)

    def createIDColumn(self, soClass):
        """Return the SQL definition of the primary-key column.

        Integer ids map to SMALLSERIAL, SERIAL or BIGSERIAL according to
        ``sqlmeta.idSize``; string ids map to TEXT.
        """
        if soClass.sqlmeta.idType is int:
            if soClass.sqlmeta.idSize in ('TINY', 'SMALL'):
                key_type = 'SMALLSERIAL'
            elif soClass.sqlmeta.idSize in ('MEDIUM', None):
                key_type = 'SERIAL'
            elif soClass.sqlmeta.idSize == 'BIG':
                key_type = 'BIGSERIAL'
            else:
                raise ValueError(
                    "sqlmeta.idSize must be 'TINY', 'SMALL', 'MEDIUM', 'BIG' "
                    "or None, not %r" % soClass.sqlmeta.idSize)
        elif soClass.sqlmeta.idType is str:
            key_type = "TEXT"
        else:
            raise TypeError('sqlmeta.idType must be int or str, not %r'
                            % soClass.sqlmeta.idType)
        return '%s %s PRIMARY KEY' % (soClass.sqlmeta.idName, key_type)

    def dropTable(self, tableName, cascade=False):
        """Drop *tableName*, optionally with CASCADE."""
        self.query("DROP TABLE %s %s" % (tableName,
                                         'CASCADE' if cascade else ''))

    def joinSQLType(self, join):
        """Return the column type used for join-table columns."""
        return 'INT NOT NULL'

    def tableExists(self, tableName):
        """Return the number of ``pg_class`` rows named *tableName*.

        The result is truthy when such a relation exists. NOTE(review):
        pg_class also holds indexes, sequences and relations from other
        schemas, and this query does not filter on either.
        """
        result = self.queryOne(
            "SELECT COUNT(relname) FROM pg_class WHERE relname = %s"
            % self.sqlrepr(tableName))
        return result[0]

    def addColumn(self, tableName, column):
        """Add *column* to *tableName* with ALTER TABLE."""
        self.query('ALTER TABLE %s ADD COLUMN %s' %
                   (tableName,
                    column.postgresCreateSQL()))

    def delColumn(self, sqlmeta, column):
        """Drop *column* from the table described by *sqlmeta*."""
        self.query('ALTER TABLE %s DROP COLUMN %s' % (sqlmeta.table,
                                                      column.dbName))

    def columnsFromSchema(self, tableName, soClass):
        """Introspect *tableName* and return a list of column objects.

        Columns referenced by a foreign-key constraint become
        ``col.ForeignKey``, and the rest are typed by ``guessClass``.
        The primary-key column is skipped. Columns whose name matches
        the column list of a non-primary index are marked
        ``alternateID``.
        """
        keyQuery = """
        SELECT pg_catalog.pg_get_constraintdef(oid) as condef
        FROM pg_catalog.pg_constraint r
        WHERE r.conrelid = %s::regclass AND r.contype = 'f'"""

        colQuery = """
        SELECT a.attname,
        pg_catalog.format_type(a.atttypid, a.atttypmod), a.attnotnull,
        (SELECT substring(pg_catalog.pg_get_expr(d.adbin, d.adrelid) for 128)
        FROM pg_catalog.pg_attrdef d
        WHERE d.adrelid=a.attrelid AND d.adnum = a.attnum)
        FROM pg_catalog.pg_attribute a
        WHERE a.attrelid =%s::regclass
        AND a.attnum > 0 AND NOT a.attisdropped
        ORDER BY a.attnum"""

        primaryKeyQuery = """
        SELECT pg_index.indisprimary,
            pg_catalog.pg_get_indexdef(pg_index.indexrelid)
        FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
            pg_catalog.pg_index AS pg_index
        WHERE c.relname = %s
            AND c.oid = pg_index.indrelid
            AND pg_index.indexrelid = c2.oid
            AND pg_index.indisprimary
        """

        otherKeyQuery = """
        SELECT pg_index.indisprimary,
            pg_catalog.pg_get_indexdef(pg_index.indexrelid)
        FROM pg_catalog.pg_class c, pg_catalog.pg_class c2,
            pg_catalog.pg_index AS pg_index
        WHERE c.relname = %s
            AND c.oid = pg_index.indrelid
            AND pg_index.indexrelid = c2.oid
            AND NOT pg_index.indisprimary
        """

        keyData = self.queryAll(keyQuery % self.sqlrepr(tableName))
        keyRE = re.compile(r"\((.+)\) REFERENCES (.+)\(")
        keymap = {}

        for (condef,) in keyData:
            match = keyRE.search(condef)
            if match:
                field, reftable = match.groups()
                keymap[field] = reftable.capitalize()

        primaryData = self.queryAll(primaryKeyQuery %
                                    self.sqlrepr(tableName))
        primaryRE = re.compile(r'CREATE .*? USING .* \((.+?)\)')
        primaryKey = None
        for isPrimary, indexDef in primaryData:
            match = primaryRE.search(indexDef)
            assert match, "Unparseable constraint definition: %r" % indexDef
            assert primaryKey is None, \
                "Already found primary key (%r), " \
                "then found: %r" % (primaryKey, indexDef)
            primaryKey = match.group(1)
        if primaryKey is None:
            # VIEWs don't have PRIMARY KEYs - accept help from user
            primaryKey = soClass.sqlmeta.idName
        assert primaryKey, "No primary key found in table %r" % tableName
        if primaryKey.startswith('"'):
            assert primaryKey.endswith('"')
            primaryKey = primaryKey[1:-1]

        otherData = self.queryAll(otherKeyQuery % self.sqlrepr(tableName))
        otherRE = primaryRE
        otherKeys = []
        for isPrimary, indexDef in otherData:
            match = otherRE.search(indexDef)
            assert match, "Unparseable constraint definition: %r" % indexDef
            otherKey = match.group(1)
            if otherKey.startswith('"'):
                assert otherKey.endswith('"')
                otherKey = otherKey[1:-1]
            otherKeys.append(otherKey)

        colData = self.queryAll(colQuery % self.sqlrepr(tableName))
        results = []
        if self.unicodeCols:
            client_encoding = self.queryOne("SHOW client_encoding")[0]
        for field, t, notnull, defaultstr in colData:
            if field == primaryKey:
                continue
            if field in keymap:
                colClass = col.ForeignKey
                kw = {'foreignKey': soClass.sqlmeta.style.
                      dbTableToPythonClass(keymap[field])}
                name = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
                if name.endswith('ID'):
                    name = name[:-2]
                kw['name'] = name
            else:
                colClass, kw = self.guessClass(t)
                if self.unicodeCols and colClass is col.StringCol:
                    colClass = col.UnicodeCol
                    kw['dbEncoding'] = client_encoding
                kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            kw['notNone'] = notnull
            if defaultstr is not None:
                kw['default'] = self.defaultFromSchema(colClass, defaultstr)
            elif not notnull:
                kw['default'] = None
            if field in otherKeys:
                kw['alternateID'] = True
            results.append(colClass(**kw))
        return results

    def guessClass(self, t):
        """Map a PostgreSQL type name *t* to ``(colClass, kwargs)``.

        *t* is the text produced by ``format_type`` in columnsFromSchema.
        The substring tests are order-sensitive, so type names that merely
        contain ``int`` are checked first. Unknown types fall back to
        ``col.Col``.
        """
        if t.count('point'):  # poINT before INT
            return col.StringCol, {}
        elif t.startswith('interval'):  # INTerval before INT
            # The 'int' substring test used to misread this as IntCol.
            # NOTE(review): a timedelta-specific column class may fit
            # better than the generic Col -- confirm what col offers.
            return col.Col, {}
        elif t.count('int'):
            return col.IntCol, {}
        elif t.count('varying') or t.count('varchar'):
            if '(' in t:
                return col.StringCol, {'length': int(t[t.index('(') + 1:-1])}
            else:  # varchar without length in Postgres means any length
                return col.StringCol, {}
        elif t.startswith('character('):
            return col.StringCol, {'length': int(t[t.index('(') + 1:-1]),
                                   'varchar': False}
        elif t.count('float') or t.count('real') or t.count('double'):
            return col.FloatCol, {}
        elif t == 'text':
            return col.StringCol, {}
        elif t.startswith('timestamp'):
            return col.DateTimeCol, {}
        elif t.startswith('datetime'):
            return col.DateTimeCol, {}
        elif t.startswith('date'):
            return col.DateCol, {}
        elif t.startswith('bool'):
            return col.BoolCol, {}
        elif t.startswith('bytea'):
            return col.BLOBCol, {}
        else:
            return col.Col, {}

    def defaultFromSchema(self, colClass, defaultstr):
        """
        If the default can be converted to a python constant, convert it.
        Otherwise return it as a sqlbuilder constant.
        """
        if colClass == col.BoolCol:
            if defaultstr == 'false':
                return False
            elif defaultstr == 'true':
                return True
        return getattr(sqlbuilder.const, defaultstr)

    def _createOrDropDatabase(self, op="CREATE"):
        """Execute ``<op> DATABASE <self.db>``; *op* is CREATE or DROP.

        The statement must be issued from a connection to another
        database, so this connects to ``template1``.
        """
        # We have to connect to *some* database, so we'll connect to
        # template1, which is a common open database.
        if self.driver in ('odbc', 'pyodbc', 'pypyodbc'):
            # @@: ODBC connections have no dsn_dict; keep the historical
            # libpq-style string. NOTE(review): ODBC drivers probably do
            # not understand this format -- confirm before relying on it.
            dsn = 'dbname=template1'
            if self.user:
                dsn += ' user=%s' % self.user
            if self.password:
                dsn += ' password=%s' % self.password
            if self.host:
                dsn += ' host=%s' % self.host
            conn = self.module.connect(dsn)
        else:
            # Reuse the keyword parameters makeConnection() connects with
            # (host, port, socket, credentials, sslmode), swapping in
            # template1. The old hand-built string dropped the port and
            # could not be used by pg8000 or py-postgresql.
            # @@: an explicit ``dsn`` string is still not consulted.
            params = dict(self.dsn_dict)
            params.pop('dbname', None)
            params.pop('database', None)
            if self.driver == 'psycopg':
                params['dbname'] = 'template1'
            else:
                params['database'] = 'template1'
            conn = self.module.connect(**params)
        try:
            cur = conn.cursor()
            try:
                # CREATE DATABASE can't run inside a transaction: enable
                # autocommit where the driver supports it...
                self._setAutoCommit(conn, 1)
                # ...and close any transaction the driver already opened.
                self._executeRetry(conn, cur, 'COMMIT')
                self._executeRetry(conn, cur,
                                   '%s DATABASE %s' % (op, self.db))
            finally:
                cur.close()
        finally:
            conn.close()

    def listTables(self):
        """Return the names of ordinary tables visible on the search path.

        Tables in the ``pg_catalog`` and ``pg_toast`` schemas are excluded.
        """
        return [v[0] for v in self.queryAll(
            """SELECT c.relname FROM pg_catalog.pg_class c
            LEFT JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace
            WHERE c.relkind IN ('r','') AND
            n.nspname NOT IN ('pg_catalog', 'pg_toast') AND
            pg_catalog.pg_table_is_visible(c.oid)""")]

    def listDatabases(self):
        """Return the names of all databases on the server."""
        return [v[0] for v in self.queryAll("SELECT datname FROM pg_database")]

    def createEmptyDatabase(self):
        """Create the database named by ``self.db``."""
        self._createOrDropDatabase()

    def dropDatabase(self):
        """Drop the database named by ``self.db``."""
        self._createOrDropDatabase(op="DROP")
# Converter for Binary types
def PsycoBinaryConverter(value, db):
    """Render a psycopg ``Binary`` wrapper as SQL text.

    The wrapper's ``str()`` form is used as-is. Only the ``'postgres'``
    dialect is accepted.
    """
    assert db == 'postgres'
    rendered = str(value)
    return rendered
if PY2:
    def escape_bytea(value):
        r"""Escape a byte string as PostgreSQL ``bytea`` octal text.

        Each character becomes a backslash followed by its ordinal as
        (at least) three octal digits, e.g. ``A`` becomes ``\101``.
        """
        return ''.join('\\%03o' % ord(ch) for ch in value)
else:
    def escape_bytea(value):
        r"""Escape a bytes value as PostgreSQL ``bytea`` octal text.

        The bytes are decoded as latin-1 so that each byte maps to exactly
        one character. Each then becomes a backslash followed by its
        value as three octal digits, e.g. ``b'A'`` becomes ``\101``.
        """
        return ''.join('\\%03o' % ord(ch) for ch in value.decode('latin1'))
def PostgresBinaryConverter(value, db):
    """Convert a driver ``Binary`` value to a quoted ``bytea`` literal.

    The bytes are octal-escaped by ``escape_bytea`` and the result is
    quoted with ``sqlrepr``. Only the ``'postgres'`` dialect is accepted.
    """
    assert db == 'postgres'
    escaped = escape_bytea(value)
    return sqlrepr(escaped, db)
def OdbcBinaryConverter(value, db):
    """Convert a ``bytearray`` for the ODBC drivers.

    Returns ``bytes`` on Python 2 and a latin-1 decoded ``str`` on
    Python 3. Only the ``'postgres'`` dialect is accepted.
    """
    assert db == 'postgres'
    raw = bytes(value)
    return raw if PY2 else raw.decode('latin1')