]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
engine.url cleanups [ticket:742]
authorJason Kirtland <jek@discorporate.us>
Wed, 29 Aug 2007 22:27:45 +0000 (22:27 +0000)
committerJason Kirtland <jek@discorporate.us>
Wed, 29 Aug 2007 22:27:45 +0000 (22:27 +0000)
- translate_connect_args can now take kw args or the classic list
- in-tree dialects updated to supply their overrides as keywords
- tweaked url parsing in the spirit of the #742 patch, more or less

CHANGES
lib/sqlalchemy/databases/access.py
lib/sqlalchemy/databases/firebird.py
lib/sqlalchemy/databases/mssql.py
lib/sqlalchemy/databases/mysql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/url.py

diff --git a/CHANGES b/CHANGES
index ccd2fdbdad1754e138dd1d2ce1140f5a91f7fba7..be5c2746749046f8fccb5455fdec1b2f6a5153fc 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -15,9 +15,10 @@ CHANGES
 
 - Tickets fixed:
 
+  - [ticket:742]
   - [ticket:748]
-  - [ticket:762]
   - [ticket:760]
+  - [ticket:762]
 
 0.4.0beta4
 ----------
index 2d36049ed7b7a7cc4f3319e89398d13f09e7daa6..7552b897d3c2cc57c36b1acd27db78c9c3038569 100644 (file)
@@ -202,14 +202,14 @@ class AccessDialect(default.DefaultDialect):
     dbapi = classmethod(dbapi)
 
     def create_connect_args(self, url):
-        opts = url.translate_connect_args(['host', 'database', 'username', 'password', 'port'])
+        opts = url.translate_connect_args()
         connectors = ["Driver={Microsoft Access Driver (*.mdb)}"]
         connectors.append("Dbq=%s" % opts["database"])
-        user = opts.get("user")
+        user = opts.get("username", None)
         if user:
             connectors.append("UID=%s" % user)
             connectors.append("PWD=%s" % opts.get("password", ""))
-        return [[";".join (connectors)], {}]
+        return [[";".join(connectors)], {}]
 
     def create_execution_context(self, *args, **kwargs):
         return AccessExecutionContext(self, *args, **kwargs)
@@ -273,8 +273,7 @@ class AccessDialect(default.DefaultDialect):
             
         # A fresh DAO connection is opened for each reflection
         # This is necessary, so we get the latest updates
-        opts = connection.engine.url.translate_connect_args(['host', 'database', 'username', 'password', 'port'])
-        dtbs = daoEngine.OpenDatabase(opts['database'])
+        dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
         
         try:
             for tbl in dtbs.TableDefs:
@@ -340,8 +339,7 @@ class AccessDialect(default.DefaultDialect):
     def table_names(self, connection, schema):
         # A fresh DAO connection is opened for each reflection
         # This is necessary, so we get the latest updates
-        opts = connection.engine.url.translate_connect_args(['host', 'database', 'username', 'password', 'port'])
-        dtbs = daoEngine.OpenDatabase(opts['database'])
+        dtbs = daoEngine.OpenDatabase(connection.engine.url.database)
 
         names = [t.Name for t in dtbs.TableDefs if t.Name[:4] != "MSys" and t.Name[:4] <> "~TMP"]
         dtbs.Close()
index 2a9bbb5bdbfcb517ee8e4b70f5bdddfc300fb5c2..a4262d9ca8f4bff79e2b15e6036369215de8aecd 100644 (file)
@@ -116,7 +116,7 @@ class FBDialect(default.DefaultDialect):
     dbapi = classmethod(dbapi)
     
     def create_connect_args(self, url):
-        opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
+        opts = url.translate_connect_args(username='user')
         if opts.get('port'):
             opts['host'] = "%s/%s" % (opts['host'], opts['port'])
             del opts['port']
index 03b276d4a2dd105f3a12b15e0ac73e8b1aba64b1..8a23ce9a38b9fc36549f2474a59100ff085a661b 100644 (file)
@@ -447,7 +447,7 @@ class MSSQLDialect(default.DefaultDialect):
     dbapi = classmethod(dbapi)
     
     def create_connect_args(self, url):
-        opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
+        opts = url.translate_connect_args(username='user')
         opts.update(url.query)
         if 'auto_identity_insert' in opts:
             self.auto_identity_insert = bool(int(opts.pop('auto_identity_insert')))
index af01cf84e3102a3dde8cc382e7bd5f61012583f5..6d5c545783b857a4ae3edfe82916ab35dc4a627d 100644 (file)
@@ -1353,7 +1353,8 @@ class MySQLDialect(default.DefaultDialect):
     dbapi = classmethod(dbapi)
     
     def create_connect_args(self, url):
-        opts = url.translate_connect_args(['host', 'db', 'user', 'passwd', 'port'])
+        opts = url.translate_connect_args(database='db', username='user',
+                                          password='passwd')
         opts.update(url.query)
 
         util.coerce_kw_type(opts, 'compress', bool)
index e2876f1f80622d0442d68978855b1b33c1b552d6..eecdcebbdfd370bc17741d3e13428505fe2670cf 100644 (file)
@@ -240,7 +240,7 @@ class PGDialect(default.DefaultDialect):
     dbapi = classmethod(dbapi)
     
     def create_connect_args(self, url):
-        opts = url.translate_connect_args(['host', 'database', 'user', 'password', 'port'])
+        opts = url.translate_connect_args(username='user')
         if 'port' in opts:
             opts['port'] = int(opts['port'])
         opts.update(url.query)
index 1f6f9eff4cd840d0661da0061218eef800e33d57..a3487638c28ad65679f8d0b80546308959fe5aa1 100644 (file)
@@ -101,24 +101,40 @@ class URL(object):
             return dialect
         raise ImportError('unknown database %r' % self.drivername) 
   
-    def translate_connect_args(self, names):
-        """Translate attributes into a dictionary of connection arguments.
-
-        Given a list of argument names corresponding to the URL
-        attributes (`host`, `database`, `username`, `password`,
-        `port`), will assemble the attribute values of this URL into
-        the dictionary using the given names.
+    def translate_connect_args(self, names=[], **kw):
+        """Translate url attributes into a dictionary of connection arguments.
+
+        Returns attributes of this url (`host`, `database`, `username`,
+        `password`, `port`) as a plain dictionary.  The attribute names are
+        used as the keys by default.  Unset or false attributes are omitted
+        from the final dictionary.
+
+        \**kw
+          Optional, alternate key names for url attributes::
+
+            # return 'username' as 'user'
+            username='user'
+
+            # omit 'database'
+            database=None
+          
+        names
+          Deprecated.  A list of key names. Equivalent to the keyword
+          usage, must be provided in the order above.
         """
 
-        a = {}
+        translated = {}
         attribute_names = ['host', 'database', 'username', 'password', 'port']
-        for n in names:
-            sname = attribute_names.pop(0)
-            if n is None:
-                continue
-            if getattr(self, sname, None):
-                a[n] = getattr(self, sname)
-        return a
+        for sname in attribute_names:
+            if names:
+                name = names.pop(0)
+            elif sname in kw:
+                name = kw[sname]
+            else:
+                name = sname
+            if name is not None and getattr(self, sname, False):
+                translated[name] = getattr(self, sname)
+        return translated
 
 def make_url(name_or_url):
     """Given a string or unicode instance, produce a new URL instance.
@@ -134,36 +150,40 @@ def make_url(name_or_url):
 
 def _parse_rfc1738_args(name):
     pattern = re.compile(r'''
-            (\w+)://
+            (?P<name>\w+)://
             (?:
-                ([^:/]*)
-                (?::([^/]*))?
+                (?P<username>[^:/]*)
+                (?::(?P<password>[^/]*))?
             @)?
             (?:
-                ([^/:]*)
-                (?::([^/]*))?
+                (?P<host>[^/:]*)
+                (?::(?P<port>[^/]*))?
             )?
-            (?:/(.*))?
+            (?:/(?P<database>.*))?
             '''
             , re.X)
 
     m = pattern.match(name)
     if m is not None:
-        (name, username, password, host, port, database) = m.group(1, 2, 3, 4, 5, 6)
-        if database is not None:
-            tokens = database.split(r"?", 2)
-            database = tokens[0]
-            query = (len(tokens) > 1 and dict( cgi.parse_qsl(tokens[1]) ) or None)
+        components = m.groupdict()
+        if components['database'] is not None:
+            tokens = components['database'].split('?', 2)
+            components['database'] = tokens[0]
+            query = (len(tokens) > 1 and dict(cgi.parse_qsl(tokens[1]))) or None
             if query is not None:
                 query = dict([(k.encode('ascii'), query[k]) for k in query])
         else:
             query = None
-        opts = {'username':username,'password':password,'host':host,'port':port,'database':database, 'query':query}
-        if opts['password'] is not None:
-            opts['password'] = urllib.unquote_plus(opts['password'])
-        return URL(name, **opts)
+        components['query'] = query
+
+        if components['password'] is not None:
+            components['password'] = urllib.unquote_plus(components['password'])
+
+        name = components.pop('name')
+        return URL(name, **components)
     else:
-        raise exceptions.ArgumentError("Could not parse rfc1738 URL from string '%s'" % name)
+        raise exceptions.ArgumentError(
+            "Could not parse rfc1738 URL from string '%s'" % name)
 
 def _parse_keyvalue_args(name):
     m = re.match( r'(\w+)://(.*)', name)