0001#!/usr/bin/env python
0002import optparse
0003import fnmatch
0004import re
0005import os
0006import sys
0007import textwrap
0008import warnings
0009try:
0010    from paste import pyconfig
0011    from paste import CONFIG
0012except ImportError, e:
0013    pyconfig = None
0014    CONFIG = {}
0015import time
0016
0017import sqlobject
0018from sqlobject import col
0019from sqlobject.util import moduleloader
0020from sqlobject.declarative import DeclarativeMeta
0021
# os.tempnam() emits a RuntimeWarning ("tempnam is a potential security
# risk ...") every time it is called; silence it for this module.  The
# filter intentionally does not pin a line number: a hard-coded lineno
# silently stops matching as soon as lines above the os.tempnam() call
# are added or removed.  The module pattern '.*command' only matches when
# this file is imported under a dotted name ending in "command"
# (presumably sqlobject.manager.command) -- TODO confirm.
warnings.filterwarnings(
    'ignore', 'tempnam is a potential security risk.*',
    RuntimeWarning, '.*command')

def nowarning_tempnam(*args, **kw):
    """
    Return ``os.tempnam(*args, **kw)``, i.e. a fresh temporary file name.

    The RuntimeWarning that ``os.tempnam`` issues is suppressed by the
    module-level filter registered just above.  NOTE(review): the name is
    only generated, not created, so another process could race for it;
    ``tempfile.mkstemp`` would avoid that.
    """
    return os.tempnam(*args, **kw)
0029
class SQLObjectVersionTable(sqlobject.SQLObject):
    """
    Table recording which version the database schema is currently at.

    This table is used to store information about the database and
    its version (used with the record and upgrade commands).
    """
    class sqlmeta:
        # Name of the table created in the database.
        table = 'sqlobject_db_version'
    # Version identifier stored for the database.
    version = col.StringCol()
    # When the row was written; the default comes from
    # col.DateTimeCol.now (presumably the current time -- confirm).
    updated = col.DateTimeCol(default=col.DateTimeCol.now)
0039
def db_differences(soClass, conn):
    """
    Return the differences between an SQLObject class and its table.

    :param soClass: the SQLObject class to check.
    :param conn: the connection whose database is inspected.
    :returns: a list of human-readable difference messages, or ``[]``
        when no differences are found.  Extra and missing columns are
        reported in sorted order so the output is stable between runs.

    This function does the best it can; it can miss many differences.
    Only column names are compared (not types), and when the connection
    cannot read columns from the schema only table existence is checked.
    """
    # @@: Repeats a lot from CommandStatus.command, but it's hard
    # to actually factor out the display logic.
    diffs = []
    table = soClass.sqlmeta.table
    if not conn.tableExists(table):
        # A class without columns has nothing to compare against.
        if soClass.sqlmeta.columns:
            diffs.append('Does not exist in database')
        return diffs
    try:
        db_columns = conn.columnsFromSchema(table, soClass)
    except AttributeError:
        # Database does not support reading columns
        return diffs
    # dbName -> column for every column found in the database; entries
    # are removed as they are matched against the class definition, so
    # whatever remains is extra in the database.
    extra = {}
    for db_col in db_columns:
        db_col = db_col.withClass(soClass)
        extra[db_col.dbName] = db_col
    missing = {}
    for class_col in soClass.sqlmeta.columnList:
        if class_col.dbName in extra:
            del extra[class_col.dbName]
        else:
            missing[class_col.dbName] = class_col
    extra_names = list(extra)
    extra_names.sort()
    for name in extra_names:
        diffs.append('Database has extra column: %s' % name)
    missing_names = list(missing)
    missing_names.sort()
    for name in missing_names:
        diffs.append('Database missing column: %s' % name)
    return diffs
0077
class CommandRunner(object):
    """
    Dispatch a command line to one of the registered command classes.

    ``commands`` maps each canonical command name to its class, and
    ``command_aliases`` maps alternative names to canonical ones.
    """

    def __init__(self):
        self.commands = {}
        self.command_aliases = {}

    def run(self, argv):
        """
        Run the command named by the first non-option word in ``argv``.

        ``argv[0]`` is the program name.  The command word is lower-cased
        and removed from the arguments; everything else is handed to the
        command class.  Missing or unknown commands go to ``invalid()``.
        """
        invoked_as, args = argv[0], argv[1:]
        command_index = None
        for index, arg in enumerate(args):
            if not arg.startswith('-'):
                # The first word that is not an option is the command.
                command_index = index
                break
        if command_index is None:
            self.invalid('No COMMAND given (try "%s help")'
                         % os.path.basename(invoked_as))
        command = args.pop(command_index).lower()
        canonical = self.command_aliases.get(command, command)
        if canonical not in self.commands:
            self.invalid('COMMAND %s unknown' % command)
        command_class = self.commands[canonical]
        command_class(invoked_as, command, args, self).run()

    def register(self, command):
        """Record ``command`` under its ``name`` and every one of its ``aliases``."""
        canonical = command.name
        self.commands[canonical] = command
        self.command_aliases.update(dict.fromkeys(command.aliases, canonical))

    def invalid(self, msg, code=2):
        """Print ``msg`` and terminate the process with exit status ``code``."""
        print(msg)
        sys.exit(code)

# Process-wide runner; ``register`` is exposed for command classes to use.
the_runner = CommandRunner()
register = the_runner.register
0116
def standard_parser(connection=True, simulate=True,
                    interactive=False, find_modules=True):
    """
    Build the ``optparse`` parser shared by the commands.

    -v/--verbose, -f/--config-file and --egg are always present.  The
    keyword flags switch on -n/--simulate (``simulate``), -c/--connection
    (``connection``), the class-selection options -m/--module,
    -p/--package and --class (``find_modules``), and -i/--interactive
    (``interactive``).
    """
    parser = optparse.OptionParser()
    add = parser.add_option

    add('-v', '--verbose', dest='verbose', action='count', default=0,
        help='Be verbose (multiple times for more verbosity)')
    if simulate:
        add('-n', '--simulate', dest='simulate', action='store_true',
            help="Don't actually do anything (implies -v)")
    if connection:
        add('-c', '--connection', dest='connection_uri', metavar='URI',
            help="The database connection URI")
    add('-f', '--config-file', dest="config_file", metavar="FILE",
        help="The Paste config file that contains the database URI "
             "(in the database key)")
    if find_modules:
        add('-m', '--module', dest='modules', action='append', default=[],
            metavar='MODULE',
            help="Module in which to find SQLObject classes")
        add('-p', '--package', dest="packages", action="append", default=[],
            metavar="PACKAGE",
            help="Package to search for SQLObject classes")
        add('--class', dest="class_matchers", action="append", default=[],
            metavar="NAME",
            help="Select only named classes (wildcards allowed)")
    if interactive:
        add('-i', '--interactive', dest="interactive", action="count",
            default=0,
            help="Ask before doing anything (use twice to be more careful)")
    add('--egg', dest="eggs", action="append", default=[],
        metavar="EGG_SPEC",
        help="Select modules from the given Egg, using sqlobject.txt")
    return parser
0171
class Command(object):
    """
    Base class for the sub-commands dispatched by ``CommandRunner``.

    Subclasses set at least ``name``, ``summary`` and ``parser`` (an
    ``optparse.OptionParser``, usually from ``standard_parser()``) and
    implement ``command()``.  ``__classinit__`` registers each subclass
    with the module-level runner; presumably ``DeclarativeMeta`` invokes
    it when the subclass is created.
    """

    __metaclass__ = DeclarativeMeta

    # Bounds on the number of positional arguments, checked in run();
    # None disables the corresponding check.
    min_args = 0
    min_args_error = 'You must provide at least %(min_args)s arguments'
    max_args = 0
    max_args_error = 'You must provide no more than %(max_args)s arguments'
    # Extra names the runner accepts for this command.
    aliases = ()
    # (options attribute, option name) pairs whose values must be set.
    required_args = []
    # Description handed to the option parser, if any.
    description = None

    # Extra help text, wrapped and appended to the usage message.
    help = ''

    def __classinit__(cls, new_args):
        if cls.__bases__ == (object,):
            # This abstract base class
            return
        register(cls)

    def __init__(self, invoked_as, command_name, args, runner):
        self.invoked_as = invoked_as
        self.command_name = command_name
        self.raw_args = args
        self.runner = runner

    def run(self):
        """
        Parse and validate the arguments, apply the config file and egg
        options, then call ``self.command()``.
        """
        self.parser.usage = "%%prog [options]\n%s" % self.summary
        if self.help:
            try:
                columns = int(os.environ.get('COLUMNS', 80))
            except ValueError:
                # Ignore a non-numeric $COLUMNS instead of crashing.
                columns = 80
            # textwrap rejects non-positive widths, so keep a floor.
            help_text = textwrap.fill(self.help, max(columns - 4, 20))
            self.parser.usage += '\n' + help_text
        self.parser.prog = '%s %s' % (
            os.path.basename(self.invoked_as),
            self.command_name)
        if self.description:
            self.parser.description = self.description
        self.options, self.args = self.parser.parse_args(self.raw_args)
        # --simulate implies at least one level of verbosity.
        if (getattr(self.options, 'simulate', False)
            and not self.options.verbose):
            self.options.verbose = 1
        if self.min_args is not None and len(self.args) < self.min_args:
            self.runner.invalid(
                self.min_args_error % {'min_args': self.min_args,
                                       'actual_args': len(self.args)})
        if self.max_args is not None and len(self.args) > self.max_args:
            self.runner.invalid(
                self.max_args_error % {'max_args': self.max_args,
                                       'actual_args': len(self.args)})
        for var_name, option_name in self.required_args:
            if not getattr(self.options, var_name, None):
                self.runner.invalid(
                    'You must provide the option %s' % option_name)
        conf = self.config()
        if conf and conf.get('sys_path'):
            update_sys_path(conf['sys_path'], self.options.verbose)
        if conf and conf.get('database'):
            # Install the configured database as sqlhub's process-wide
            # connection.
            conn = sqlobject.connectionForURI(conf['database'])
            sqlobject.sqlhub.processConnection = conn
        for egg_spec in getattr(self.options, 'eggs', []):
            self.load_options_from_egg(egg_spec)
        self.command()

    def classes(self, require_connection=True,
                require_some=False):
        """
        Collect the SQLObject classes selected by --module, --package,
        --egg and --class, and attach the configured connection to them.

        When no connection is configured and ``require_connection`` is
        true, exits via ``runner.invalid`` if any class lacks one.  When
        ``require_some`` is true and nothing was found, prints where it
        looked and exits with status 1.  Returns the list of classes.
        """
        found = []
        for module_name in self.options.modules:
            found.extend(self.classes_from_module(
                moduleloader.load_module(module_name)))
        for package_name in self.options.packages:
            found.extend(self.classes_from_package(package_name))
        for egg_spec in self.options.eggs:
            found.extend(self.classes_from_egg(egg_spec))
        if self.options.class_matchers:
            # Keep classes whose name matches any of the wildcard patterns.
            filtered = []
            for soClass in found:
                name = soClass.__name__
                for matcher in self.options.class_matchers:
                    if fnmatch.fnmatch(name, matcher):
                        filtered.append(soClass)
                        break
            found = filtered
        conn = self.connection()
        if conn:
            for soClass in found:
                soClass._connection = conn
        else:
            missing = []
            for soClass in found:
                try:
                    if not soClass._connection:
                        missing.append(soClass)
                except AttributeError:
                    missing.append(soClass)
            if missing and require_connection:
                self.runner.invalid(
                    'These classes do not have connections set:\n  * %s\n'
                    'You must indicate --connection=URI'
                    % '\n  * '.join([soClass.__name__
                                     for soClass in missing]))
        if require_some and not found:
            print('No classes found!')
            if self.options.modules:
                print('Looked in modules: %s'
                      % ', '.join(self.options.modules))
            else:
                print('No modules specified')
            if self.options.packages:
                print('Looked in packages: %s'
                      % ', '.join(self.options.packages))
            else:
                print('No packages specified')
            if self.options.class_matchers:
                print('Matching class pattern: %s'
                      % ', '.join(self.options.class_matchers))
            if self.options.eggs:
                print('Looked in eggs: %s' % ', '.join(self.options.eggs))
            else:
                print('No eggs specified')
            sys.exit(1)
        return found

    def classes_from_module(self, module):
        """
        Return the SQLObject classes of ``module``.

        If the module defines ``soClasses`` (classes or attribute names),
        exactly those are returned; otherwise every SQLObject subclass
        defined in the module itself (not imported into it) is returned.
        """
        found = []
        if hasattr(module, 'soClasses'):
            for name_or_class in module.soClasses:
                if isinstance(name_or_class, str):
                    name_or_class = getattr(module, name_or_class)
                found.append(name_or_class)
        else:
            for name in dir(module):
                value = getattr(module, name)
                if (isinstance(value, type)
                    and issubclass(value, sqlobject.SQLObject)
                    and value.__module__ == module.__name__):
                    found.append(value)
        return found

    def connection(self):
        """
        Return a connection built from the config file's ``database``
        value, or from --connection, or None when neither is given.

        Raises AssertionError when a config file is given but has no
        ``database`` value.
        """
        config = self.config()
        if config is not None:
            # Raised explicitly (not with ``assert``) so the check also
            # runs under ``python -O``.
            if not config.get('database'):
                raise AssertionError(
                    "No database variable found in config file %s"
                    % self.options.config_file)
            return sqlobject.connectionForURI(config['database'])
        if getattr(self.options, 'connection_uri', None):
            return sqlobject.connectionForURI(self.options.connection_uri)
        return None

    def config(self):
        """
        Load the --config-file, if one was given.

        ``.conf`` files are loaded with Paste's pyconfig when it is
        available (and pushed as the process config); anything else is
        read as an INI file.  Returns None when no config file was given.
        """
        config_file = getattr(self.options, 'config_file', None)
        if not config_file:
            return None
        if pyconfig and config_file.endswith('.conf'):
            config = pyconfig.Config(with_default=True)
            config.load(config_file)
            CONFIG.push_process_config(config)
            return config
        return self.ini_config(config_file)

    def ini_config(self, conf_fn):
        """
        Read one section of an INI file and return it as a dict.

        ``conf_fn`` may end in ``#section`` to pick the section (default
        ``main``).  A section matches if its lower-cased name equals it,
        or is ``app:<section>`` / ``application:<section>``.  ``here`` is
        available for interpolation as the file's directory.

        Raises OSError if the file does not exist, or if zero or several
        sections match.
        """
        conf_section = 'main'
        if '#' in conf_fn:
            conf_fn, conf_section = conf_fn.split('#', 1)

        from ConfigParser import ConfigParser
        p = ConfigParser()
        # Case-sensitive:
        p.optionxform = str
        if not os.path.exists(conf_fn):
            # ConfigParser.read() silently ignores missing files.
            raise OSError(
                "Config file %s does not exist" % self.options.config_file)
        p.read([conf_fn])
        p.defaults().setdefault(
            'here', os.path.dirname(os.path.abspath(conf_fn)))

        possible_sections = []
        for section in p.sections():
            name = section.strip().lower()
            if (conf_section == name or
                (conf_section == name.split(':')[-1]
                 and name.split(':')[0] in ('app', 'application'))):
                possible_sections.append(section)

        if not possible_sections:
            raise OSError(
                "Config file %s does not have a section [%s] or [*:%s]"
                % (conf_fn, conf_section, conf_section))
        if len(possible_sections) > 1:
            raise OSError(
                "Config file %s has multiple sections matching %s: %s"
                % (conf_fn, conf_section, ', '.join(possible_sections)))

        config = {}
        for op in p.options(possible_sections[0]):
            config[op] = p.get(possible_sections[0], op)
        return config

    def classes_from_package(self, package_name):
        """
        Return SQLObject classes from every ``.py`` module below package
        ``package_name`` (``__init__.py`` files are skipped).

        Modules that fail to import are skipped; with --verbose the error
        is printed.
        """
        found = []
        package = moduleloader.load_module(package_name)
        package_dir = os.path.dirname(package.__file__)

        def find_classes_in_file(arg, dir_name, filenames):
            # Prune Subversion metadata directories from the walk (the
            # walk descends into whatever remains in ``filenames``), and
            # skip them by path component: dir_name is a full path, so a
            # prefix test can never match.
            if '.svn' in filenames:
                filenames.remove('.svn')
            if '.svn' in dir_name.split(os.sep):
                return
            # Build the dotted module prefix from this directory's position
            # below the package directory.  This works for dotted package
            # names and for paths in which the package name also appears
            # higher up.
            rel_dir = dir_name[len(package_dir):].strip(os.sep)
            prefix = package_name
            if rel_dir:
                prefix = '%s.%s' % (prefix, rel_dir.replace(os.sep, '.'))
            for fname in filenames:
                if not fname.endswith('.py') or fname == '__init__.py':
                    continue
                module_name = '%s.%s' % (prefix, fname[:-3])
                try:
                    module = moduleloader.load_module(module_name)
                except ImportError:
                    if self.options.verbose:
                        print('Could not import module "%s". Error was : "%s"'
                              % (module_name, sys.exc_info()[1]))
                    continue
                except Exception:
                    if self.options.verbose:
                        print('Unknown exception while processing module "%s" : "%s"'
                              % (module_name, sys.exc_info()[1]))
                    continue
                found.extend(self.classes_from_module(module))

        os.path.walk(package_dir, find_classes_in_file, None)
        return found

    def classes_from_egg(self, egg_spec):
        """
        Return SQLObject classes from the modules listed (comma-separated)
        under ``db_module`` in the egg's ``sqlobject.txt``.
        """
        found = []
        dist, conf = self.config_from_egg(egg_spec, warn_no_sqlobject=True)
        for mod in conf.get('db_module', '').split(','):
            mod = mod.strip()
            if not mod:
                continue
            if self.options.verbose:
                print('Looking in module %s' % mod)
            found.extend(self.classes_from_module(
                moduleloader.load_module(mod)))
        return found

    def load_options_from_egg(self, egg_spec):
        """
        Fill in option defaults from the egg's ``sqlobject.txt``.

        Currently only ``history_dir`` is used: if the command has an
        unset ``output_dir`` option it becomes ``history_dir``, with
        ``$base`` replaced by the distribution's location.
        """
        dist, conf = self.config_from_egg(egg_spec)
        if (hasattr(self.options, 'output_dir')
            and not self.options.output_dir
            and conf.get('history_dir')):
            history_dir = conf['history_dir']
            history_dir = history_dir.replace('$base', dist.location)
            self.options.output_dir = history_dir

    def config_from_egg(self, egg_spec, warn_no_sqlobject=True):
        """
        Require the distribution ``egg_spec`` and parse the ``name = value``
        lines of its ``sqlobject.txt`` metadata.

        Returns ``(dist, settings)`` where ``settings`` maps lower-cased
        names to stripped values.  ``settings`` is empty when the egg has
        no ``sqlobject.txt``; a note is printed then if
        ``warn_no_sqlobject`` is true.
        """
        import pkg_resources
        pkg_resources.require(egg_spec)
        dist = pkg_resources.working_set.find(
            pkg_resources.Requirement.parse(egg_spec))
        result = {}
        if not dist.has_metadata('sqlobject.txt'):
            if warn_no_sqlobject:
                print('No sqlobject.txt in %s egg info' % egg_spec)
            # Callers always unpack a (dist, settings) pair.
            return dist, result
        for line in dist.get_metadata_lines('sqlobject.txt'):
            line = line.strip()
            if not line or line.startswith('#'):
                continue
            if '=' not in line:
                print('Warning: ignoring line without "=" in sqlobject.txt: %s'
                      % line)
                continue
            name, value = line.split('=', 1)
            name = name.strip().lower()
            if name in result:
                print('Warning: %s appears more than once in sqlobject.txt'
                      % name)
            result[name] = value.strip()
        return dist, result

    def command(self):
        """Do the actual work of the command; subclasses must override."""
        raise NotImplementedError

    def _get_prog_name(self):
        return os.path.basename(self.invoked_as)
    prog_name = property(_get_prog_name)

    def ask(self, prompt, safe=False, default=True):
        """
        Ask a yes/no question on the terminal and return the answer.

        An empty answer returns ``default``.  With -i given twice or more,
        ``safe`` is used as the default instead.  The question is repeated
        until the answer starts with y or n (case-insensitive).
        """
        if self.options.interactive >= 2:
            default = safe
        if default:
            prompt += ' [Y/n]? '
        else:
            prompt += ' [y/N]? '
        while True:
            response = raw_input(prompt).strip()
            if not response:
                return default
            if response[0].lower() in ('y', 'n'):
                return response[0].lower() == 'y'
            print('Y or N please')

    def shorten_filename(self, fn):
        """
        Shortens a filename to make it relative to the current
        directory (if it can).  For display purposes.
        """
        cwd = os.getcwd()
        if fn.startswith(cwd + os.sep):
            fn = fn[len(cwd) + 1:]
        return fn

    def open_editor(self, pretext, breaker=None, extension='.txt'):
        """
        Open $EDITOR on a temporary file containing ``pretext``.

        Returns the edited text, or None if it is unchanged or blank.  If
        ``breaker`` is given, everything after it is ignored in both the
        original and the edited text.  The temporary file is removed
        afterwards.
        """
        fn = nowarning_tempnam() + extension
        try:
            f = open(fn, 'w')
            try:
                f.write(pretext)
            finally:
                f.close()
            print('$EDITOR %s' % fn)
            os.system('$EDITOR %s' % fn)
            f = open(fn, 'r')
            try:
                content = f.read()
            finally:
                f.close()
        finally:
            if os.path.exists(fn):
                os.unlink(fn)
        if breaker:
            content = content.split(breaker)[0]
            pretext = pretext.split(breaker)[0]
        if content == pretext or not content.strip():
            return None
        return content
0495
class CommandSQL(Command):
    """``sql`` command: print the CREATE statement of every selected class."""

    name = 'sql'
    summary = 'Show SQL CREATE statements'

    parser = standard_parser(simulate=False)

    def command(self):
        show_origin = self.options.verbose >= 1
        for soClass in self.classes():
            if show_origin:
                print('-- %s from %s' % (soClass.__name__, soClass.__module__))
            statement = soClass.createTableSQL().strip()
            print(statement + ';\n')
0510
class CommandList(Command):
    """``list`` command: print the dotted name of every selected class."""

    name = 'list'
    summary = 'Show all SQLObject classes found'

    parser = standard_parser(simulate=False, connection=False)

    def command(self):
        verbose = self.options.verbose >= 1
        if verbose:
            print('Classes found:')
        for soClass in self.classes(require_connection=False):
            print('%s.%s' % (soClass.__module__, soClass.__name__))
            if verbose:
                print('  Table: %s' % soClass.sqlmeta.table)
0526
class CommandCreate(Command):
    """
    ``create`` command: create the tables of the selected classes that do
    not exist yet (optionally creating the database first).
    """

    name = 'create'
    summary = 'Create tables'

    parser = standard_parser(interactive=True)
    parser.add_option('--create-db',
                      action='store_true',
                      dest='create_db',
                      help="Create the database")

    def command(self):
        verbosity = self.options.verbose
        num_created = num_existing = 0
        # Connections whose database has already been created (or would
        # have been, when simulating).
        initialized = []
        for soClass in self.classes(require_some=True):
            conn = soClass._connection
            if self.options.create_db and conn not in initialized:
                if self.options.simulate:
                    print('(simulating; cannot create database)')
                else:
                    conn.createEmptyDatabase()
                initialized.append(conn)
            table_exists = conn.tableExists(soClass.sqlmeta.table)
            if verbosity >= 1:
                if table_exists:
                    num_existing += 1
                    print('%s already exists.' % soClass.__name__)
                else:
                    print('Creating %s' % soClass.__name__)
            if verbosity >= 2:
                print(soClass.createTableSQL())
            if self.options.simulate or table_exists:
                continue
            if (self.options.interactive
                and not self.ask('Create %s' % soClass.__name__)):
                print('Cancelled')
                continue
            num_created += 1
            soClass.createTable()
        if verbosity >= 1:
            print('%i tables created (%i already exist)'
                  % (num_created, num_existing))
0574
0575
class CommandDrop(Command):
    """``drop`` command: drop the existing tables of the selected classes."""

    name = 'drop'
    summary = 'Drop tables'

    parser = standard_parser(interactive=True)

    def command(self):
        verbosity = self.options.verbose
        num_dropped = num_absent = 0
        for soClass in self.classes():
            table_exists = soClass._connection.tableExists(
                soClass.sqlmeta.table)
            if verbosity >= 1:
                if table_exists:
                    print('Dropping %s' % soClass.__name__)
                else:
                    num_absent += 1
                    print('%s does not exist.' % soClass.__name__)
            if self.options.simulate or not table_exists:
                continue
            if (self.options.interactive
                and not self.ask('Drop %s' % soClass.__name__)):
                print('Cancelled')
                continue
            num_dropped += 1
            soClass.dropTable()
        if verbosity >= 1:
            print("%i tables dropped (%i didn't exist)"
                  % (num_dropped, num_absent))
0609
class CommandStatus(Command):
    """
    ``status`` command: compare the selected classes with the tables in
    the database and report missing tables and extra/missing columns.
    """

    name = 'status'
    summary = 'Show status of classes vs. database'
    help = ('This command checks the SQLObject definition and checks if '
            'the tables in the database match.  It can always test for '
            'missing tables, and on some databases can test for the '
            'existance of other tables.  Column types are not currently '
            'checked.')

    parser = standard_parser(simulate=False)

    def print_class(self, soClass):
        """Print the "Checking ..." header for ``soClass`` at most once."""
        if not self.printed:
            self.printed = True
            print('Checking %s...' % soClass.__name__)

    def command(self):
        in_sync = out_of_sync = absent_tables = 0
        warned_no_schema = False
        for soClass in self.classes(require_some=True):
            conn = soClass._connection
            table = soClass.sqlmeta.table
            self.printed = False
            if self.options.verbose:
                self.print_class(soClass)
            if not conn.tableExists(table):
                self.print_class(soClass)
                print('  Does not exist in database')
                absent_tables += 1
                continue
            try:
                schema_columns = conn.columnsFromSchema(table, soClass)
            except AttributeError:
                # Connection cannot introspect columns: warn once and
                # count the table as in sync.
                if not warned_no_schema:
                    print('Database does not support reading columns')
                    warned_no_schema = True
                in_sync += 1
                continue
            except AssertionError:
                print('Cannot read db table %s: %s'
                      % (table, sys.exc_info()[1]))
                continue
            # Start from every database column and strike out the ones the
            # class defines; what remains is extra in the database.
            extra = {}
            for column in schema_columns:
                column = column.withClass(soClass)
                extra[column.dbName] = column
            absent_columns = {}
            for column in soClass.sqlmeta.columnList:
                if column.dbName in extra:
                    del extra[column.dbName]
                else:
                    absent_columns[column.dbName] = column
            if extra:
                self.print_class(soClass)
                for column in extra.values():
                    print('  Database has extra column: %s' % column.dbName)
            if absent_columns:
                self.print_class(soClass)
                for column in absent_columns.values():
                    print('  Database missing column: %s' % column.dbName)
            if extra or absent_columns:
                out_of_sync += 1
            else:
                in_sync += 1
        if self.options.verbose:
            print('%i in sync; %i out of sync; %i not in database'
                  % (in_sync, out_of_sync, absent_tables))
0681
class CommandHelp(Command):
    """
    ``help`` command: list all commands, or show one command's own help
    when its name is given.
    """

    name = 'help'
    summary = 'Show help'

    parser = optparse.OptionParser()

    max_args = 1

    def command(self):
        if self.args:
            # Re-dispatch as "COMMAND -h" so the command prints its help.
            the_runner.run([self.invoked_as, self.args[0], '-h'])
            return
        prog = self.prog_name
        print('Available commands:')
        print('  (use "%s help COMMAND" or "%s COMMAND -h" ' % (prog, prog))
        print('  for more information)')
        entries = list(the_runner.commands.items())
        entries.sort()
        width = max([len(cmd_name) for cmd_name, cmd_class in entries])
        for cmd_name, cmd_class in entries:
            padding = ' ' * (width - len(cmd_name))
            print('%s:%s %s' % (cmd_name, padding, cmd_class.summary))
            if cmd_class.aliases:
                print('%s (Aliases: %s)'
                      % (' ' * width, ', '.join(cmd_class.aliases)))
0709
class CommandExecute(Command):
    """
    ``execute`` command: run each argument (and optionally stdin) as one
    SQL statement on the configured database, printing any result rows.
    """

    name = 'execute'
    summary = 'Execute SQL statements'
    help = ('Runs SQL statements directly in the database, with no '
            'intervention.  Useful when used with a configuration file.  '
            'Each argument is executed as an individual statement.')

    parser = standard_parser(find_modules=False)
    parser.add_option('--stdin',
                      help="Read SQL from stdin (normally takes SQL from the command line)",
                      dest="use_stdin",
                      action="store_true")

    max_args = None

    def command(self):
        args = self.args
        if self.options.use_stdin:
            if self.options.verbose:
                print("Reading additional SQL from stdin (Ctrl-D or Ctrl-Z to finish)...")
            args.append(sys.stdin.read())
        connection = self.connection()
        if connection is None:
            # Fail with a usage error instead of an AttributeError on None.
            self.runner.invalid(
                'No database given; use --connection=URI or '
                '--config-file=FILE')
            return
        self.conn = connection.getConnection()
        self.cursor = self.conn.cursor()
        for sql in args:
            self.execute_sql(sql)

    def execute_sql(self, sql):
        """
        Execute one statement and print its result set as tab-separated
        rows.  Errors are printed and do not stop later statements.
        """
        if self.options.verbose:
            print(sql)
        try:
            self.cursor.execute(sql)
        except Exception:
            error = sys.exc_info()[1]
            if not self.options.verbose:
                print(sql)
            print("****Error:")
            print('     %s' % (error,))
            return
        desc = self.cursor.description
        # Per DB-API 2.0, description is None for statements that return
        # no rows (INSERT/UPDATE/DDL) and fetchall() then raises.
        if desc:
            rows = self.cursor.fetchall()
        else:
            rows = []
        if self.options.verbose:
            if not self.cursor.rowcount:
                print("No rows accessed")
            else:
                print("%i rows accessed" % self.cursor.rowcount)
        if desc:
            for column_desc in desc:
                # The first item of each description entry is the column name.
                sys.stdout.write("%s\t" % column_desc[0])
            sys.stdout.write("\n")
        for row in rows:
            for value in row:
                sys.stdout.write("%r\t" % value)
            sys.stdout.write("\n")
        sys.stdout.write("\n")
0764
0765class CommandRecord(Command):
0766
0767    name = 'record'
0768    summary = 'Record historical information about the database status'
0769    help = ('Record state of table definitions.  The state of each '
0770            'table is written out to a separate file in a directory, '
0771            'and that directory forms a "version".  A table is also '
0772            'added to you datebase (%s) that reflects the version the '
0773            'database is currently at.  Use the upgrade command to '
0774            'sync databases with code.'
0775            % SQLObjectVersionTable.sqlmeta.table)
0776
0777    parser = standard_parser()
0778    parser.add_option('--output-dir',
0779                      help="Base directory for recorded definitions",
0780                      dest="output_dir",
0781                      metavar="DIR",
0782                      default=None)
0783    parser.add_option('--no-db-record',
0784                      help="Don't record version to database",
0785                      dest="db_record",
0786                      action="store_false",
0787                      default=True)
0788    parser.add_option('--force-create',
0789                      help="Create a new version even if appears to be "
0790                      "identical to the last version",
0791                      action="store_true",
0792                      dest="force_create")
0793    parser.add_option('--name',
0794                      help="The name to append to the version.  The "
0795                      "version should sort after previous versions (so "
0796                      "any versions from the same day should come "
0797                      "alphabetically before this version).",
0798                      dest="version_name",
0799                      metavar="NAME")
0800    parser.add_option('--force-db-version',
0801                      help="Update the database version, and include no "
0802                      "database information.  This is for databases that "
0803                      "were developed without any interaction with "
0804                      "this tool, to create a 'beginning' revision.",
0805                      metavar="VERSION_NAME",
0806                      dest="force_db_version")
0807    parser.add_option('--edit',
0808                      help="Open an editor for the upgrader in the last "
0809                      "version (using $EDITOR).",
0810                      action="store_true",
0811                      dest="open_editor")
0812
0813    version_regex = re.compile(r'^\d\d\d\d-\d\d-\d\d')
0814
0815    def command(self):
0816        if self.options.force_db_version:
0817            self.command_force_db_version()
0818            return
0819
0820        v = self.options.verbose
0821        sim = self.options.simulate
0822        classes = self.classes()
0823        if not classes:
0824            print "No classes found!"
0825            return
0826
0827        output_dir = self.find_output_dir()
0828        version = os.path.basename(output_dir)
0829        print "Creating version %s" % version
0830        conns = []
0831        files = {}
0832        for cls in self.classes():
0833            dbName = cls._connection.dbName
0834            if cls._connection not in conns:
0835                conns.append(cls._connection)
0836            fn = os.path.join(cls.__name__
0837                              + '_' + dbName + '.sql')
0838            if sim:
0839                continue
0840            files[fn] = ''.join([
0841                '-- Exported definition from %s\n'
0842                % time.strftime('%Y-%m-%dT%H:%M:%S'),
0843                '-- Class %s.%s\n'
0844                % (cls.__module__, cls.__name__),
0845                '-- Database: %s\n'
0846                % dbName,
0847                cls.createTableSQL().strip(),
0848                '\n'])
0849        last_version_dir = self.find_last_version()
0850        if last_version_dir and not self.options.force_create:
0851            if v > 1:
0852                print "Checking %s to see if it is current" % last_version_dir
0853            files_copy = files.copy()
0854            for fn in os.listdir(last_version_dir):
0855                if not fn.endswith('.sql'):
0856                    continue
0857                if not files_copy.has_key(fn):
0858                    if v > 1:
0859                        print "Missing file %s" % fn
0860                    break
0861                f = open(os.path.join(last_version_dir, fn), 'r')
0862                content = f.read()
0863                f.close()
0864                if (self.strip_comments(files_copy[fn])
0865                    != self.strip_comments(content)):
0866                    if v > 1:
0867                        print "Content does not match: %s" % fn
0868                    break
0869                del files_copy[fn]
0870            else:
0871                # No differences so far
0872                if not files_copy:
0873                    # Used up all files
0874                    print ("Current status matches version %s"
0875                           % os.path.basename(last_version_dir))
0876                    return
0877                if v > 1:
0878                    print "Extra files: %s" % ', '.join(files_copy.keys())
0879            if v:
0880                print ("Current state does not match %s"
0881                       % os.path.basename(last_version_dir))
0882        if v > 1 and not last_version_dir:
0883            print "No last version to check"
0884        if not sim:
0885            os.mkdir(output_dir)
0886        if v:
0887            print 'Making directory %s' % self.shorten_filename(output_dir)
0888        files = files.items()
0889        files.sort()
0890        for fn, content in files:
0891            if v:
0892                print '  Writing %s' % self.shorten_filename(fn)
0893            if not sim:
0894                f = open(os.path.join(output_dir, fn), 'w')
0895                f.write(content)
0896                f.close()
0897        all_diffs = []
0898        for cls in self.classes():
0899            for conn in conns:
0900                diffs = db_differences(cls, conn)
0901                for diff in diffs:
0902                    if len(conns) > 1:
0903                        diff = '  (%s).%s: %s' % (
0904                            conn.uri(), cls.sqlmeta.table, diff)
0905                    else:
0906                        diff = '  %s: %s' % (cls.sqlmeta.table, diff)
0907                    all_diffs.append(diff)
0908        if all_diffs:
0909            print 'Database does not match schema:'
0910            print '\n'.join(all_diffs)
0911            if self.options.db_record:
0912                print '(Not updating database version)'
0913        elif self.options.db_record:
0914            for conn in conns:
0915                self.update_db(version, conn)
0916        if self.options.open_editor:
0917            if not last_version_dir:
0918                print ("Cannot edit upgrader because there is no "
0919                       "previous version")
0920            else:
0921                breaker = ('-'*20 + ' lines below this will be ignored '
0922                           + '-'*20)
0923                pre_text = breaker + '\n' + '\n'.join(all_diffs)
0924                text = self.open_editor('\n\n' + pre_text, breaker=breaker,
0925                                        extension='.sql')
0926                if text is not None:
0927                    fn = os.path.join(last_version_dir,
0928                                      'upgrade_%s_%s.sql' %
0929                                      (dbName, version))
0930                    f = open(fn, 'w')
0931                    f.write(text)
0932                    f.close()
0933                    print 'Wrote to %s' % fn
0934
0935    def update_db(self, version, conn):
0936        v = self.options.verbose
0937        if not conn.tableExists(SQLObjectVersionTable.sqlmeta.table):
0938            if v:
0939                print ('Creating table %s'
0940                       % SQLObjectVersionTable.sqlmeta.table)
0941            sql = SQLObjectVersionTable.createTableSQL(connection=conn)
0942            if v > 1:
0943                print sql
0944            if not self.options.simulate:
0945                SQLObjectVersionTable.createTable(connection=conn)
0946        if not self.options.simulate:
0947            SQLObjectVersionTable.clearTable(connection=conn)
0948            SQLObjectVersionTable(
0949                version=version,
0950                connection=conn)
0951
0952    def strip_comments(self, sql):
0953        lines = [l for l in sql.splitlines()
0954                 if not l.strip().startswith('--')]
0955        return '\n'.join(lines)
0956
0957    def base_dir(self):
0958        base = self.options.output_dir
0959        if base is None:
0960            base = CONFIG.get('sqlobject_history_dir', '.')
0961        if not os.path.exists(base):
0962            print 'Creating history directory %s' % self.shorten_filename(base)
0963            if not self.options.simulate:
0964                os.makedirs(base)
0965        return base
0966
0967    def find_output_dir(self):
0968        today = time.strftime('%Y-%m-%d', time.localtime())
0969        if self.options.version_name:
0970            dir = os.path.join(self.base_dir(), today + '-' +
0971                               self.options.version_name)
0972            if os.path.exists(dir):
0973                print ("Error, directory already exists: %s"
0974                       % dir)
0975                sys.exit(1)
0976            return dir
0977        extra = ''
0978        while 1:
0979            dir = os.path.join(self.base_dir(), today + extra)
0980            if not os.path.exists(dir):
0981                return dir
0982            if not extra:
0983                extra = 'a'
0984            else:
0985                extra = chr(ord(extra)+1)
0986
0987    def find_last_version(self):
0988        names = []
0989        for fn in os.listdir(self.base_dir()):
0990            if not self.version_regex.search(fn):
0991                continue
0992            names.append(fn)
0993        if not names:
0994            return None
0995        names.sort()
0996        return os.path.join(self.base_dir(), names[-1])
0997
0998    def command_force_db_version(self):
0999        v = self.options.verbose
1000        sim = self.options.simulate
1001        version = self.options.force_db_version
1002        if not self.version_regex.search(version):
1003            print "Versions must be in the format YYYY-MM-DD..."
1004            print "You version %s does not fit this" % version
1005            return
1006        version_dir = os.path.join(self.base_dir(), version)
1007        if not os.path.exists(version_dir):
1008            if v:
1009                print 'Creating %s' % self.shorten_filename(version_dir)
1010            if not sim:
1011                os.mkdir(version_dir)
1012        elif v:
1013            print ('Directory %s exists'
1014                   % self.shorten_filename(version_dir))
1015        if self.options.db_record:
1016            self.update_db(version, self.connection())
1017
1018class CommandUpgrade(CommandRecord):
1019
1020    name = 'upgrade'
1021    summary = 'Update the database to a new version (as created by record)'
1022    help = ('This command runs scripts (that you write by hand) to '
1023            'upgrade a database.  The database\'s current version is in '
1024            'the sqlobject_version table (use record --force-db-version '
1025            'if a database does not have a sqlobject_version table), '
1026            'and upgrade scripts are in the version directory you are '
1027            'upgrading FROM, named upgrade_DBNAME_VERSION.sql, like '
1028            '"upgrade_mysql_2004-12-01b.sql".')
1029
1030    parser = standard_parser(find_modules=False)
1031    parser.add_option('--upgrade-to',
1032                      help="Upgrade to the given version (default: newest version)",
1033                      dest="upgrade_to",
1034                      metavar="VERSION")
1035    parser.add_option('--output-dir',
1036                      help="Base directory for recorded definitions",
1037                      dest="output_dir",
1038                      metavar="DIR",
1039                      default=None)
1040
1041    upgrade_regex = re.compile(r'^upgrade_([a-z]*)_([^.]*)\.sql$', re.I)
1042
1043    def command(self):
1044        v = self.options.verbose
1045        sim = self.options.simulate
1046        if self.options.upgrade_to:
1047            version_to = self.options.upgrade_to
1048        else:
1049            version_to = os.path.basename(self.find_last_version())
1050        current = self.current_version()
1051        if v:
1052            print 'Current version: %s' % current
1053        version_list = self.make_plan(current, version_to)
1054        if not version_list:
1055            print 'Database up to date'
1056            return
1057        if v:
1058            print 'Plan:'
1059            for next_version, upgrader in version_list:
1060                print '  Use %s to upgrade to %s' % (
1061                    self.shorten_filename(upgrader), next_version)
1062        conn = self.connection()
1063        for next_version, upgrader in version_list:
1064            f = open(upgrader)
1065            sql = f.read()
1066            f.close()
1067            if v:
1068                print "Running:"
1069                print sql
1070                print '-'*60
1071            if not sim:
1072                try:
1073                    conn.query(sql)
1074                except:
1075                    print "Error in script: %s" % upgrader
1076                    raise
1077            self.update_db(next_version, conn)
1078        print 'Done.'
1079
1080
1081    def current_version(self):
1082        conn = self.connection()
1083        if not conn.tableExists(SQLObjectVersionTable.sqlmeta.table):
1084            print 'No sqlobject_version table!'
1085            sys.exit(1)
1086        versions = list(SQLObjectVersionTable.select(connection=conn))
1087        if not versions:
1088            print 'No rows in sqlobject_version!'
1089            sys.exit(1)
1090        if len(versions) > 1:
1091            print 'Ambiguous sqlobject_version_table'
1092            sys.exit(1)
1093        return versions[0].version
1094
1095    def make_plan(self, current, dest):
1096        if current == dest:
1097            return []
1098        dbname = self.connection().dbName
1099        next_version, upgrader = self.best_upgrade(current, dest, dbname)
1100        if not upgrader:
1101            print 'No way to upgrade from %s to %s' % (current, dest)
1102            print ('(you need a %s/upgrade_%s_%s.sql script)'
1103                   % (current, dbname, dest))
1104            sys.exit(1)
1105        plan = [(next_version, upgrader)]
1106        if next_version == dest:
1107            return plan
1108        else:
1109            return plan + self.make_plan(next_version, dest)
1110
1111    def best_upgrade(self, current, dest, target_dbname):
1112        current_dir = os.path.join(self.base_dir(), current)
1113        if self.options.verbose > 1:
1114            print ('Looking in %s for upgraders'
1115                   % self.shorten_filename(current_dir))
1116        upgraders = []
1117        for fn in os.listdir(current_dir):
1118            match = self.upgrade_regex.search(fn)
1119            if not match:
1120                if self.options.verbose > 1:
1121                    print 'Not an upgrade script: %s' % fn
1122                continue
1123            dbname = match.group(1)
1124            version = match.group(2)
1125            if dbname != target_dbname:
1126                if self.options.verbose > 1:
1127                    print 'Not for this database: %s (want %s)' % (
1128                        dbname, target_dbname)
1129                continue
1130            if version > dest:
1131                if self.options.verbose > 1:
1132                    print 'Version too new: %s (only want %s)' % (
1133                        version, dest)
1134            upgraders.append((version, os.path.join(current_dir, fn)))
1135        if not upgraders:
1136            if self.options.verbose > 1:
1137                print 'No upgraders found in %s' % current_dir
1138            return None, None
1139        upgraders.sort()
1140        return upgraders[-1]
1141
1142def update_sys_path(paths, verbose):
1143    if isinstance(paths, (str, unicode)):
1144        paths = [paths]
1145    for path in paths:
1146        path = os.path.abspath(path)
1147        if path not in sys.path:
1148            if verbose > 1:
1149                print 'Adding %s to path' % path
1150            sys.path.insert(0, path)
1151
if __name__ == '__main__':
    # Script entry point: pass the whole command line (program name
    # included) to ``the_runner``.
    argv = sys.argv
    the_runner.run(argv)