from sqlobject import col, dberrors
from sqlobject.compat import PY2, string_type
from sqlobject.converters import registerConverter, StringLikeConverter
from sqlobject.dbconnection import DBAPI
class ErrorMessage(str):
    """A ``str`` carrying a driver error message plus error metadata.

    The string value is the driver's message with *append_msg* appended.
    Extra attributes set on the instance:

    * ``code`` -- the error code: an ``int`` when it can be converted,
      otherwise the raw value (e.g. an SQLSTATE string), or ``None`` when
      the exception carries no code at all;
    * ``module`` -- the ``__module__`` of the exception;
    * ``exception`` -- the name of the exception class.
    """

    def __new__(cls, e, append_msg=''):
        if e.__module__ == 'cymysql.err':
            # cymysql keeps message and code in dedicated attributes; the
            # message may be a string or an object exposing ``reason``.
            if isinstance(e.errmsg, string_type):
                errmsg = e.errmsg
            else:
                errmsg = e.errmsg.reason
            errcode = e.errno
        else:
            # Other drivers: args are read as (code, message, ...).
            args = e.args
            errmsg = args[1] if len(args) > 1 else ''
            errcode = args[0] if args else None
            try:
                errcode = int(errcode)
            except (TypeError, ValueError):
                # Not numeric (e.g. SQLSTATE '42S02') or None: keep as is.
                pass
        obj = str.__new__(cls, errmsg + append_msg)
        obj.code = errcode
        obj.module = e.__module__
        obj.exception = e.__class__.__name__
        return obj
# Original ``Binary`` constructor of the first driver module whose
# ``Binary`` gets wrapped in MySQLConnection.__init__ (Python 3 only).
# Stays None on Python 2 and until the first MySQLConnection is created.
mysql_Bin = None
class MySQLConnection(DBAPI):
    """SQLObject connection to a MySQL or MariaDB server.

    The DB-API driver is chosen from the ``driver`` keyword argument, a
    comma-separated list of driver names tried in order (default
    ``'mysqldb'``); the first one that imports successfully is used.
    """

    supportTransactions = False
    dbName = 'mysql'
    schemes = [dbName]
    # ODBC connection-string keywords; presumably consumed by
    # DBAPI.make_odbc_conn_str -- confirm there.
    odbc_keywords = ('Server', 'Port', 'UID', 'Password', 'Database')

    def __init__(self, db, user, password='', host='localhost', port=0, **kw):
        """Select a driver and store the connection parameters.

        :param db: database name.
        :param user: user name.
        :param password: password.
        :param host: server host.
        :param port: server port; 0 means 3306.
        :param kw: ``driver`` (comma-separated driver list), driver options
            (``unix_socket``, ``init_command``, ``read_default_file``,
            ``read_default_group``, ``conv``, ``connect_timeout``,
            ``compress``, ``named_pipe``, ``use_unicode``, ``client_flag``,
            ``local_infile``, ``ssl_key``, ``ssl_cert``, ``ssl_ca``,
            ``ssl_capath``, ``charset``, ``odbcdrv``); everything else is
            passed to ``DBAPI.__init__``.
        :raises ImportError: if none of the requested drivers can be
            imported.
        :raises ValueError: for an unknown driver name or a MySQLdb older
            than 1.2.2.
        """
        drivers = kw.pop('driver', None) or 'mysqldb'
        for driver in drivers.split(','):
            driver = driver.strip().lower()
            if not driver:
                continue
            try:
                if driver in ('mysqldb', 'pymysql'):
                    if driver == 'pymysql':
                        import pymysql
                        # Make ``import MySQLdb`` resolve to PyMySQL.
                        pymysql.install_as_MySQLdb()
                    import MySQLdb
                    if driver == 'mysqldb':
                        if MySQLdb.version_info[:3] < (1, 2, 2):
                            raise ValueError(
                                'SQLObject requires MySQLdb 1.2.2 or later')
                    import MySQLdb.constants.CR
                    import MySQLdb.constants.ER
                    self.module = MySQLdb
                    # MySQLdb and PyMySQL name the CR constants differently.
                    if driver == 'mysqldb':
                        self.CR_SERVER_GONE_ERROR = \
                            MySQLdb.constants.CR.SERVER_GONE_ERROR
                        self.CR_SERVER_LOST = \
                            MySQLdb.constants.CR.SERVER_LOST
                    else:
                        self.CR_SERVER_GONE_ERROR = \
                            MySQLdb.constants.CR.CR_SERVER_GONE_ERROR
                        self.CR_SERVER_LOST = \
                            MySQLdb.constants.CR.CR_SERVER_LOST
                    self.ER_DUP_ENTRY = MySQLdb.constants.ER.DUP_ENTRY
                elif driver == 'cymysql':
                    import cymysql
                    import cymysql.constants.CR
                    import cymysql.constants.ER
                    self.module = cymysql
                    self.CR_SERVER_GONE_ERROR = \
                        cymysql.constants.CR.CR_SERVER_GONE_ERROR
                    self.CR_SERVER_LOST = \
                        cymysql.constants.CR.CR_SERVER_LOST
                    self.ER_DUP_ENTRY = cymysql.constants.ER.DUP_ENTRY
                elif driver in ('connector', 'connector-python'):
                    import mysql.connector
                    self.module = mysql.connector
                    self.CR_SERVER_GONE_ERROR = \
                        mysql.connector.errorcode.CR_SERVER_GONE_ERROR
                    self.CR_SERVER_LOST = \
                        mysql.connector.errorcode.CR_SERVER_LOST
                    self.ER_DUP_ENTRY = mysql.connector.errorcode.ER_DUP_ENTRY
                    if driver == 'connector-python':
                        self.connector_type = 'mysql.connector-python'
                    else:
                        self.connector_type = 'mysql.connector'
                elif driver == 'mariadb':
                    import mariadb
                    self.module = mariadb
                elif driver == 'pyodbc':
                    import pyodbc
                    self.module = pyodbc
                elif driver == 'pypyodbc':
                    import pypyodbc
                    self.module = pypyodbc
                elif driver == 'odbc':
                    try:
                        import pyodbc
                    except ImportError:
                        import pypyodbc as pyodbc
                    self.module = pyodbc
                else:
                    raise ValueError(
                        'Unknown MySQL driver "%s", '
                        'expected mysqldb, connector, connector-python, '
                        'pymysql, cymysql, mariadb, '
                        'odbc, pyodbc or pypyodbc' % driver)
            except ImportError:
                # This driver is not installed; try the next one.
                pass
            else:
                break
        else:
            raise ImportError(
                'Cannot find a MySQL driver, tried %s' % drivers)

        self.host = host
        self.port = port or 3306
        self.db = db
        self.user = user
        self.password = password
        # Extra keyword arguments forwarded to the driver's connect().
        self.kw = {}
        for key in ("unix_socket", "init_command",
                    "read_default_file", "read_default_group", "conv"):
            if key in kw:
                self.kw[key] = kw.pop(key)
        for key in ("connect_timeout", "compress", "named_pipe", "use_unicode",
                    "client_flag", "local_infile"):
            if key in kw:
                self.kw[key] = int(kw.pop(key))
        if driver in ('connector', 'connector-python'):
            # mysql.connector receives the SSL options as flat keywords.
            for key in ("ssl_key", "ssl_cert", "ssl_ca", "ssl_capath"):
                if key in kw:
                    self.kw[key] = kw.pop(key)
        else:
            # Other drivers receive them as an ``ssl`` dict, keys without
            # the ``ssl_`` prefix.
            for key in ("ssl_key", "ssl_cert", "ssl_ca", "ssl_capath"):
                if key in kw:
                    if "ssl" not in self.kw:
                        self.kw["ssl"] = {}
                    self.kw["ssl"][key[4:]] = kw.pop(key)
        if "charset" in kw:
            self.dbEncoding = self.kw["charset"] = kw.pop("charset")
        else:
            self.dbEncoding = None
        self.driver = driver

        if driver in ('mariadb', 'odbc', 'pyodbc', 'pypyodbc'):
            # No error-code constants were imported for these drivers
            # above: use the MySQL client error numbers and, for duplicate
            # entries, the SQLSTATE string.
            self.CR_SERVER_GONE_ERROR = 2006
            self.CR_SERVER_LOST = 2013
            self.ER_DUP_ENTRY = '23000'

        if driver in ('odbc', 'pyodbc', 'pypyodbc'):
            self.make_odbc_conn_str(kw.pop('odbcdrv',
                                           'MySQL ODBC 5.3 ANSI Driver'),
                                    db, host, port, user, password
                                    )
        elif driver == 'mariadb':
            # charset is not passed to mariadb.connect(); dbEncoding keeps it.
            self.kw.pop("charset", None)
        elif driver in ('connector', 'connector-python'):
            registerConverter(bytes, ConnectorBytesConverter)

        # On Python 3 wrap the driver's Binary() so it returns a str
        # (bytes decoded with surrogateescape).
        # NOTE(review): guarded by the module-level mysql_Bin, so this runs
        # once per process; a later connection using a different driver
        # module leaves that module's Binary unwrapped -- confirm intent.
        global mysql_Bin
        if not PY2 and mysql_Bin is None:
            mysql_Bin = self.module.Binary
            self.module.Binary = lambda x: mysql_Bin(x).decode(
                'ascii', errors='surrogateescape')

        # Lazily computed caches (see server_version() and can_use_*()).
        self._server_version = None
        self._can_use_microseconds = None
        self._can_use_json_funcs = None
        DBAPI.__init__(self, **kw)

    @classmethod
    def _connectionFromParams(cls, user, password, host, port, path, args):
        """Build a connection from parsed URI parts; *path* is the db name."""
        return cls(db=path.strip('/'),
                   user=user or '', password=password or '',
                   host=host or 'localhost', port=port, **args)

    def makeConnection(self):
        """Open and return a new DB-API connection.

        Enables reconnect-on-ping, applies the autocommit setting and, when
        ``charset`` was given, the connection character set.

        :raises dberrors.OperationalError: if the driver fails to connect;
            the message includes host, port, db and user (not the password).
        """
        dbEncoding = self.dbEncoding
        if dbEncoding:
            if self.driver in ('mysqldb', 'pymysql'):
                from MySQLdb.connections import Connection
                if not hasattr(Connection, 'set_character_set'):
                    # monkeypatch pre MySQLdb 1.2.1
                    def character_set_name(self):
                        return dbEncoding + '_' + dbEncoding
                    Connection.character_set_name = character_set_name
        if self.driver in ('connector', 'connector-python'):
            self.kw['consume_results'] = True
        try:
            if self.driver in ('odbc', 'pyodbc', 'pypyodbc'):
                self.debugWriter.write(
                    "ODBC connect string: " + self.odbc_conn_str)
                conn = self.module.connect(self.odbc_conn_str)
            else:
                conn = self.module.connect(
                    host=self.host, port=self.port, db=self.db,
                    user=self.user, passwd=self.password, **self.kw)
                if self.driver == 'mariadb':
                    # Attempt to reconnect.
                    # This setting is persistent due to ``auto_reconnect``.
                    # mariadb doesn't implement ping(True)
                    conn.auto_reconnect = True
                    conn.ping()
                else:
                    # Attempt to reconnect. This setting is persistent.
                    conn.ping(True)
        except self.module.OperationalError as e:
            conninfo = ("; used connection string: "
                        "host=%(host)s, port=%(port)s, "
                        "db=%(db)s, user=%(user)s" % self.__dict__)
            raise dberrors.OperationalError(ErrorMessage(e, conninfo))

        self._setAutoCommit(conn, bool(self.autoCommit))

        if dbEncoding:
            if self.driver in ('odbc', 'pyodbc'):
                conn.setdecoding(self.module.SQL_CHAR, encoding=dbEncoding)
                conn.setdecoding(self.module.SQL_WCHAR, encoding=dbEncoding)
                if PY2:
                    conn.setencoding(str, encoding=dbEncoding)
                    conn.setencoding(unicode, encoding=dbEncoding)  # noqa
                else:
                    conn.setencoding(encoding=dbEncoding)
            elif hasattr(conn, 'set_character_set'):
                conn.set_character_set(dbEncoding)
            elif hasattr(conn, 'query'):
                # works along with monkeypatching code above
                conn.query("SET NAMES %s" % dbEncoding)

        return conn

    def _setAutoCommit(self, conn, auto):
        """Set autocommit on *conn* whether it is a method or a property."""
        if hasattr(conn, 'autocommit'):
            try:
                conn.autocommit(auto)
            except TypeError:
                # mysql-connector{-python} has autocommit as a property
                conn.autocommit = auto

    def _force_reconnect(self, conn):
        """Reconnect a PyMySQL/cymysql connection and restore its settings.

        Does nothing for other drivers.
        """
        if self.driver in ('pymysql', 'cymysql'):
            conn.ping(True)
            self._setAutoCommit(conn, bool(self.autoCommit))
            if self.dbEncoding:
                conn.query("SET NAMES %s" % self.dbEncoding)

    def _executeRetry(self, conn, cursor, query):
        """Execute *query* on *cursor*, retrying on lost connections.

        Driver exceptions are translated into ``dberrors`` exceptions
        (driver warnings into the built-in ``Warning``).

        :returns: whatever ``cursor.execute()`` returns.
        """
        if self.debug:
            self.printDebug(conn, query, 'QueryR')
        dbEncoding = self.dbEncoding
        if dbEncoding and not isinstance(query, bytes) and (
                self.driver in ('mysqldb', 'connector', 'connector-python',
                                'mariadb')):
            query = query.encode(dbEncoding, 'surrogateescape')
        # When a server connection is lost and a query is attempted, most of
        # the time the query will raise a SERVER_LOST exception, then at the
        # second attempt to execute it, the mysql lib will reconnect and
        # succeed. However in a few cases, the first attempt raises the
        # SERVER_GONE exception, the second attempt the SERVER_LOST exception
        # and only the third succeeds. Thus the 3 in the loop count.
        # If it doesn't reconnect even after 3 attempts, while the database is
        # up and running, it is because a 5.0.3 (or newer) server is used
        # which no longer permits autoreconnects by default. In that case a
        # reconnect flag must be set when making the connection to indicate
        # that autoreconnecting is desired. In MySQLdb 1.2.2 or newer this is
        # done by calling ping(True) on the connection.
        # [PC]yMySQL need explicit reconnect
        # each time we detect connection timeout.
        for count in range(3):
            try:
                return cursor.execute(query)
            except self.module.OperationalError as e:
                err_code = e.args[0] if e.args else None
                if err_code in (self.CR_SERVER_GONE_ERROR,
                                self.CR_SERVER_LOST):
                    if count == 2:
                        raise dberrors.OperationalError(ErrorMessage(e))
                    if self.debug:
                        self.printDebug(conn, str(e), 'ERROR')
                    if self.driver in ('pymysql', 'cymysql'):
                        self._force_reconnect(conn)
                else:
                    raise dberrors.OperationalError(ErrorMessage(e))
            except self.module.IntegrityError as e:
                msg = ErrorMessage(e)
                err_code = e.args[0] if e.args else None
                if err_code == self.ER_DUP_ENTRY:
                    raise dberrors.DuplicateEntryError(msg)
                elif isinstance(err_code, str) \
                        and err_code.startswith('Duplicate'):
                    raise dberrors.DuplicateEntryError(msg)
                else:
                    raise dberrors.IntegrityError(msg)
            except self.module.InternalError as e:
                raise dberrors.InternalError(ErrorMessage(e))
            except self.module.ProgrammingError as e:
                # NOTE(review): a ProgrammingError whose first arg is None
                # is swallowed and the query re-executed (None is returned
                # after three such attempts); kept as-is -- confirm intent.
                if not e.args or e.args[0] is not None:
                    raise dberrors.ProgrammingError(ErrorMessage(e))
            except self.module.DataError as e:
                raise dberrors.DataError(ErrorMessage(e))
            except self.module.NotSupportedError as e:
                raise dberrors.NotSupportedError(ErrorMessage(e))
            except self.module.DatabaseError as e:
                raise dberrors.DatabaseError(ErrorMessage(e))
            except self.module.InterfaceError as e:
                raise dberrors.InterfaceError(ErrorMessage(e))
            except self.module.Warning as e:
                raise Warning(ErrorMessage(e))
            except self.module.Error as e:
                raise dberrors.Error(ErrorMessage(e))

    def _queryInsertID(self, conn, soInstance, id, names, values):
        """INSERT a row for *soInstance* and return its id.

        When *id* is not None it is inserted explicitly and returned;
        otherwise the server-generated id is read back from the cursor.
        The cursor is always closed, even if the INSERT fails.
        """
        table = soInstance.sqlmeta.table
        idName = soInstance.sqlmeta.idName
        c = conn.cursor()
        try:
            if id is not None:
                names = [idName] + names
                values = [id] + values
            q = self._insertSQL(table, names, values)
            if self.debug:
                self.printDebug(conn, q, 'QueryIns')
            self._executeRetry(conn, c, q)
            if id is None:
                # Try ``lastrowid``, then an ``insert_id()`` method, and
                # finally ask the server directly.
                try:
                    id = c.lastrowid
                except AttributeError:
                    try:
                        insert_id = c.insert_id
                    except AttributeError:
                        self._executeRetry(conn, c,
                                           "SELECT LAST_INSERT_ID();")
                        id = c.fetchone()[0]
                    else:
                        id = insert_id()
        finally:
            c.close()
        if self.debugOutput:
            self.printDebug(conn, id, 'QueryIns', 'result')
        return id

    @classmethod
    def _queryAddLimitOffset(cls, query, start, end):
        """Append a MySQL ``LIMIT`` clause for rows ``[start:end]``.

        MySQL has no "no limit" value for the row count (``-1`` is a syntax
        error), so an open-ended slice uses the largest BIGINT UNSIGNED
        value, as the MySQL manual recommends.
        """
        if not start:
            return "%s LIMIT %i" % (query, end)
        if not end:
            return "%s LIMIT %i, 18446744073709551615" % (query, start)
        return "%s LIMIT %i, %i" % (query, start, end - start)

    def createReferenceConstraint(self, soClass, col):
        """Return the column's MySQL reference-constraint SQL."""
        return col.mysqlCreateReferenceConstraint()

    def createColumn(self, soClass, col):
        """Return the column's MySQL definition SQL."""
        return col.mysqlCreateSQL(self)

    def createIndexSQL(self, soClass, index):
        """Return the index's MySQL CREATE INDEX SQL."""
        return index.mysqlCreateIndexSQL(soClass)

    def createIDColumn(self, soClass):
        """Return the id column definition for *soClass*.

        :raises TypeError: if ``sqlmeta.idType`` is neither int nor str.
        :raises ValueError: for an unsupported ``sqlmeta.idSize``.
        """
        if soClass.sqlmeta.idType is str:
            # NOTE(review): MySQL rejects a TEXT key without a prefix
            # length (error 1170); confirm str ids work on this backend.
            return '%s TEXT PRIMARY KEY' % soClass.sqlmeta.idName
        if soClass.sqlmeta.idType is not int:
            raise TypeError('sqlmeta.idType must be int or str, not %r'
                            % soClass.sqlmeta.idType)
        if soClass.sqlmeta.idSize is None:
            mysql_int_type = 'INT'
        elif soClass.sqlmeta.idSize in ('TINY', 'SMALL', 'MEDIUM', 'BIG'):
            mysql_int_type = '%sINT' % soClass.sqlmeta.idSize
        else:
            raise ValueError(
                "sqlmeta.idSize must be 'TINY', 'SMALL', 'MEDIUM', 'BIG' "
                "or None, not %r" % soClass.sqlmeta.idSize)
        return '%s %s PRIMARY KEY AUTO_INCREMENT' \
            % (soClass.sqlmeta.idName, mysql_int_type)

    def joinSQLType(self, join):
        """Return the SQL type of join-table columns."""
        return 'INT NOT NULL'

    def tableExists(self, tableName):
        """Return True if *tableName* can be DESCRIBEd.

        Returns False on a "no such table" error (or on any
        ProgrammingError with the mariadb driver); re-raises otherwise.
        """
        try:
            # Use DESCRIBE instead of SHOW TABLES because SHOW TABLES
            # assumes there is a default database selected
            # which is not always True (for an embedded application, e.g.)
            self.query('DESCRIBE %s' % (tableName))
            return True
        except dberrors.ProgrammingError as e:
            if e.args[0].code in (1146, '42S02'):  # ER_NO_SUCH_TABLE
                return False
            if self.driver == 'mariadb':
                return False
            raise

    def addColumn(self, tableName, column):
        """ALTER TABLE to add *column*."""
        self.query('ALTER TABLE %s ADD COLUMN %s' %
                   (tableName,
                    column.mysqlCreateSQL(self)))

    def delColumn(self, sqlmeta, column):
        """ALTER TABLE to drop *column*."""
        self.query('ALTER TABLE %s DROP COLUMN %s' % (sqlmeta.table,
                                                      column.dbName))

    def columnsFromSchema(self, tableName, soClass):
        """Build column objects from ``SHOW COLUMNS``, skipping the id."""
        colData = self.queryAll("SHOW COLUMNS FROM %s"
                                % tableName)
        results = []
        for field, t, nullAllowed, key, default, extra in colData:
            if field == soClass.sqlmeta.idName:
                continue
            colClass, kw = self.guessClass(t)
            if self.kw.get('use_unicode') and colClass is col.StringCol:
                colClass = col.UnicodeCol
                if self.dbEncoding:
                    kw['dbEncoding'] = self.dbEncoding
            kw['name'] = soClass.sqlmeta.style.dbColumnToPythonAttr(field)
            kw['dbName'] = field
            # Since MySQL 5.0, 'NO' is returned in the NULL column
            # (SQLObject expected '')
            kw['notNone'] = nullAllowed.upper() != 'YES'
            if not PY2 and isinstance(t, bytes):
                t = t.decode('ascii')
            if default and t.startswith('int'):
                kw['default'] = int(default)
            elif default and t.startswith(('float', 'double')):
                kw['default'] = float(default)
            elif default == 'CURRENT_TIMESTAMP' and t == 'timestamp':
                kw['default'] = None
            elif default and colClass is col.BoolCol:
                kw['default'] = bool(int(default))
            else:
                kw['default'] = default
            # @@ skip key...
            # @@ skip extra...
            results.append(colClass(**kw))
        return results

    def guessClass(self, t):
        """Map a MySQL column type string to ``(ColumnClass, kwargs)``.

        Unrecognized types map to the generic ``col.Col``.
        """
        if not PY2 and isinstance(t, bytes):
            t = t.decode('ascii')
        if t.startswith('int'):
            return col.IntCol, {}
        elif t.startswith('enum'):
            # Strip "enum(" and ")", split on commas, drop the quotes.
            values = [v[1:-1] for v in t[5:-1].split(',')]
            return col.EnumCol, {'enumValues': values}
        elif t.startswith(('double', 'float')):
            return col.FloatCol, {}
        elif t.startswith('varchar'):
            colType = col.StringCol
            if self.kw.get('use_unicode', False):
                colType = col.UnicodeCol
            if t.endswith('binary'):
                return colType, {'length': int(t[8:-8]),
                                 'char_binary': True}
            else:
                return colType, {'length': int(t[8:-1])}
        elif t.startswith('char'):
            if t.endswith('binary'):
                return col.StringCol, {'length': int(t[5:-8]),
                                       'varchar': False,
                                       'char_binary': True}
            else:
                return col.StringCol, {'length': int(t[5:-1]),
                                       'varchar': False}
        elif t.startswith('datetime'):
            return col.DateTimeCol, {}
        elif t.startswith('date'):
            return col.DateCol, {}
        elif t.startswith('timestamp'):
            # Must precede the 'time' test: 'timestamp' starts with 'time'.
            return col.TimestampCol, {}
        elif t.startswith('time'):
            return col.TimeCol, {}
        elif t.startswith('bool'):
            return col.BoolCol, {}
        elif t.startswith('tinyblob'):
            return col.BLOBCol, {"length": 2 ** 8 - 1}
        elif t.startswith('tinytext'):
            return col.StringCol, {"length": 2 ** 8 - 1, "varchar": True}
        elif t.startswith('blob'):
            return col.BLOBCol, {"length": 2 ** 16 - 1}
        elif t.startswith('text'):
            return col.StringCol, {"length": 2 ** 16 - 1, "varchar": True}
        elif t.startswith('mediumblob'):
            return col.BLOBCol, {"length": 2 ** 24 - 1}
        elif t.startswith('mediumtext'):
            return col.StringCol, {"length": 2 ** 24 - 1, "varchar": True}
        elif t.startswith('longblob'):
            return col.BLOBCol, {"length": 2 ** 32}
        elif t.startswith('longtext'):
            return col.StringCol, {"length": 2 ** 32, "varchar": True}
        else:
            return col.Col, {}

    def listTables(self):
        """Return the table names of the current database."""
        return _decodeBytearrays(self.queryAll("SHOW TABLES"))

    def listDatabases(self):
        """Return the database names visible on the server."""
        return _decodeBytearrays(self.queryAll("SHOW DATABASES"))

    def _createOrDropDatabase(self, op="CREATE"):
        """Run ``<op> DATABASE <self.db>``."""
        self.query('%s DATABASE %s' % (op, self.db))

    def createEmptyDatabase(self):
        """CREATE the configured database."""
        self._createOrDropDatabase()

    def dropDatabase(self):
        """DROP the configured database."""
        self._createOrDropDatabase(op="DROP")

    def server_version(self):
        """Return ``(version_tuple, "MySQL" | "MariaDB")``, or None.

        None means the version could not be determined; only a successful
        result is served from the cache on later calls.
        """
        if self._server_version is not None:
            return self._server_version
        try:
            server_version = self.queryOne("SELECT VERSION()")[0]
            server_version = server_version.split('-', 1)
            db_tag = "MySQL"
            if len(server_version) == 2:
                if "MariaDB" in server_version[1]:
                    db_tag = "MariaDB"
            server_version = server_version[0]
            server_version = tuple(int(v) for v in server_version.split('.'))
            server_version = (server_version, db_tag)
        except Exception:
            server_version = None  # unknown
        self._server_version = server_version
        return server_version

    def can_use_microseconds(self):
        """Whether the server supports fractional seconds.

        True for MariaDB >= 5.3.0 or MySQL >= 5.6.4; None if the server
        version is unknown.
        """
        if self._can_use_microseconds is not None:
            return self._can_use_microseconds
        server_version = self.server_version()
        if server_version is None:
            return None
        server_version, db_tag = server_version
        if db_tag == "MariaDB":
            can_use_microseconds = (server_version >= (5, 3, 0))
        else:  # MySQL
            can_use_microseconds = (server_version >= (5, 6, 4))
        self._can_use_microseconds = can_use_microseconds
        return can_use_microseconds

    def can_use_json_funcs(self):
        """Whether the server supports JSON functions.

        True for MariaDB >= 10.2.7 or MySQL >= 5.7.0; None if the server
        version is unknown.
        """
        if self._can_use_json_funcs is not None:
            return self._can_use_json_funcs
        server_version = self.server_version()
        if server_version is None:
            return None
        server_version, db_tag = server_version
        if db_tag == "MariaDB":
            can_use_json_funcs = (server_version >= (10, 2, 7))
        else:  # MySQL
            can_use_json_funcs = (server_version >= (5, 7, 0))
        self._can_use_json_funcs = can_use_json_funcs
        return can_use_json_funcs
def ConnectorBytesConverter(value, db):
    """SQL converter registered for ``bytes`` with mysql.connector drivers.

    On Python 3 the bytes are decoded as latin-1 (every byte maps to one
    character) and the resulting text is handed to ``StringLikeConverter``.
    On Python 2 the value is passed to ``StringLikeConverter`` untouched.
    """
    if PY2:
        # Under PY2 this converter also runs for SQLite, so the database
        # sanity check below applies only to Python 3.
        return StringLikeConverter(value, db)
    assert db == 'mysql'
    text = value.decode('latin1')
    return StringLikeConverter(text, db)
def _decodeBytearrays(v_list):
if not v_list:
return []
if not PY2 and isinstance(v_list[0][0], bytearray):
return [v[0].decode('ascii') for v in v_list]
return [v[0] for v in v_list]