0001from array import array
0002import datetime
0003from decimal import Decimal
0004import sys
0005import time
0006from types import ClassType, InstanceType, NoneType
0007
0008
0009try:
0010    from mx.DateTime import DateTimeType, DateTimeDeltaType
0011except ImportError:
0012    try:
0013        from DateTime import DateTimeType, DateTimeDeltaType
0014    except ImportError:
0015        DateTimeType = None
0016        DateTimeDeltaType = None
0017
0018try:
0019    import Sybase
0020    NumericType=Sybase.NumericType
0021except ImportError:
0022    NumericType = None
0023
0024
0025########################################
0026## Quoting
0027########################################
0028
0029sqlStringReplace = [
0030    ("'", "''"),
0031    ('\\', '\\\\'),
0032    ('\000', '\\0'),
0033    ('\b', '\\b'),
0034    ('\n', '\\n'),
0035    ('\r', '\\r'),
0036    ('\t', '\\t'),
0037]
0038
0039class ConverterRegistry:
0040
0041    def __init__(self):
0042        self.basic = {}
0043        self.klass = {}
0044
0045    def registerConverter(self, typ, func):
0046        if type(typ) is ClassType:
0047            self.klass[typ] = func
0048        else:
0049            self.basic[typ] = func
0050
0051    def lookupConverter(self, value, default=None):
0052        if type(value) is InstanceType:
0053            # lookup on klasses dict
0054            return self.klass.get(value.__class__, default)
0055        return self.basic.get(type(value), default)
0056
0057converters = ConverterRegistry()
0058registerConverter = converters.registerConverter
0059lookupConverter = converters.lookupConverter
0060
0061def StringLikeConverter(value, db):
0062    if isinstance(value, array):
0063        try:
0064            value = value.tounicode()
0065        except ValueError:
0066            value = value.tostring()
0067    elif isinstance(value, buffer):
0068        value = str(value)
0069
0070    if db in ('mysql', 'postgres', 'rdbhost'):
0071        for orig, repl in sqlStringReplace:
0072            value = value.replace(orig, repl)
0073    elif db in ('sqlite', 'firebird', 'sybase', 'maxdb', 'mssql'):
0074        value = value.replace("'", "''")
0075    else:
0076        assert 0, "Database %s unknown" % db
0077    if db in ('postgres', 'rdbhost') and ('\\' in value):
0078        return "E'%s'" % value
0079    return "'%s'" % value
0080
0081registerConverter(str, StringLikeConverter)
0082registerConverter(unicode, StringLikeConverter)
0083registerConverter(array, StringLikeConverter)
0084registerConverter(buffer, StringLikeConverter)
0085
0086def IntConverter(value, db):
0087    return repr(int(value))
0088
0089registerConverter(int, IntConverter)
0090
0091def LongConverter(value, db):
0092    return str(value)
0093
0094registerConverter(long, LongConverter)
0095
0096if NumericType:
0097    registerConverter(NumericType, IntConverter)
0098
0099def BoolConverter(value, db):
0100    if db in ('postgres', 'rdbhost'):
0101        if value:
0102            return "'t'"
0103        else:
0104            return "'f'"
0105    else:
0106        if value:
0107            return '1'
0108        else:
0109            return '0'
0110
0111registerConverter(bool, BoolConverter)
0112
0113def FloatConverter(value, db):
0114    return repr(value)
0115
0116registerConverter(float, FloatConverter)
0117
0118if DateTimeType:
0119    def DateTimeConverter(value, db):
0120        return "'%s'" % value.strftime("%Y-%m-%d %H:%M:%S.%s")
0121
0122    registerConverter(DateTimeType, DateTimeConverter)
0123
0124    def TimeConverter(value, db):
0125        return "'%s'" % value.strftime("%H:%M:%S")
0126
0127    registerConverter(DateTimeDeltaType, TimeConverter)
0128
0129def NoneConverter(value, db):
0130    return "NULL"
0131
0132registerConverter(NoneType, NoneConverter)
0133
0134def SequenceConverter(value, db):
0135    return "(%s)" % ", ".join([sqlrepr(v, db) for v in value])
0136
0137registerConverter(tuple, SequenceConverter)
0138registerConverter(list, SequenceConverter)
0139registerConverter(dict, SequenceConverter)
0140registerConverter(set, SequenceConverter)
0141registerConverter(frozenset, SequenceConverter)
0142if sys.version_info[:3] < (2, 6, 0): # Module sets was deprecated in Python 2.6
0143   from sets import Set, ImmutableSet
0144   registerConverter(Set, SequenceConverter)
0145   registerConverter(ImmutableSet, SequenceConverter)
0146
0147if hasattr(time, 'struct_time'):
0148    def StructTimeConverter(value, db):
0149        return time.strftime("'%Y-%m-%d %H:%M:%S'", value)
0150
0151    registerConverter(time.struct_time, StructTimeConverter)
0152
0153def DateTimeConverter(value, db):
0154    return "'%04d-%02d-%02d %02d:%02d:%02d.%d'" % (
0155        value.year, value.month, value.day,
0156        value.hour, value.minute, value.second, value.microsecond)
0157
0158registerConverter(datetime.datetime, DateTimeConverter)
0159
0160def DateConverter(value, db):
0161    return "'%04d-%02d-%02d'" % (value.year, value.month, value.day)
0162
0163registerConverter(datetime.date, DateConverter)
0164
0165def TimeConverter(value, db):
0166    return "'%02d:%02d:%02d.%d'" % (value.hour, value.minute, value.second, value.microsecond)
0167
0168registerConverter(datetime.time, TimeConverter)
0169
0170def DecimalConverter(value, db):
0171    # See http://mail.python.org/pipermail/python-dev/2008-March/078189.html
0172    return str(value.to_eng_string()) # Convert to str to work around a bug in Python 2.5.2
0173
0174registerConverter(Decimal, DecimalConverter)
0175
0176def TimedeltaConverter(value, db):
0177
0178    return """INTERVAL '%d days %d seconds'""" %           (value.days, value.seconds)
0180
0181registerConverter(datetime.timedelta, TimedeltaConverter)
0182
0183
0184def sqlrepr(obj, db=None):
0185    try:
0186        reprFunc = obj.__sqlrepr__
0187    except AttributeError:
0188        converter = lookupConverter(obj)
0189        if converter is None:
0190            raise ValueError, "Unknown SQL builtin type: %s for %s" %                     (type(obj), repr(obj))
0192        return converter(obj, db)
0193    else:
0194        return reprFunc(db)
0195
0196
0197def quote_str(s, db):
0198    if db in ('postgres', 'rdbhost') and ('\\' in s):
0199        return "E'%s'" % s
0200    return "'%s'" % s
0201
0202def unquote_str(s):
0203    if s[:2].upper().startswith("E'") and s.endswith("'"):
0204        return s[2:-1]
0205    elif s.startswith("'") and s.endswith("'"):
0206        return s[1:-1]
0207    else:
0208        return s