]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
implemented query string support in db urls, gets sent to dialect **kwargs, [ticket...
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Jun 2006 02:03:36 +0000 (02:03 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 1 Jun 2006 02:03:36 +0000 (02:03 +0000)
CHANGES
lib/sqlalchemy/engine/strategies.py
lib/sqlalchemy/engine/url.py
test/parseconnect.py

diff --git a/CHANGES b/CHANGES
index 06917f7689c039de468da70ed1fec0c2391a4478..6dc0aec0b301e3c0a0bb4dd440e9591d639dad7c 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -7,6 +7,8 @@ more unit tests
 - fix to docs, removed incorrect info that close() is unsafe to use
 with threadlocal strategy (its totally safe !)
 - create_engine() can take URLs as string or unicode [ticket:188]
+- firebird support !  thanks to James Ralston and Brad Clements for their
+efforts.
 
 0.2.1
 - "pool" argument to create_engine() properly propigates
index d8b503add196abd59c7935d0787be9d05d5f2e53..e2f5c8b7c585b5a986e1161c4d6df5b3e1496357 100644 (file)
@@ -29,7 +29,9 @@ class PlainEngineStrategy(EngineStrategy):
         u = url.make_url(name_or_url)
         module = u.get_module()
 
-        dialect = module.dialect(**kwargs)
+        args = u.query.copy()
+        args.update(kwargs)
+        dialect = module.dialect(**args)
 
         poolargs = {}
         for key in (('echo_pool', 'echo'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool')):
@@ -41,7 +43,7 @@ class PlainEngineStrategy(EngineStrategy):
         poolargs['use_threadlocal'] = False
         provider = default.PoolConnectionProvider(dialect, u, **poolargs)
 
-        return base.ComposedSQLEngine(provider, dialect, **kwargs)
+        return base.ComposedSQLEngine(provider, dialect, **args)
 PlainEngineStrategy()
 
 class ThreadLocalEngineStrategy(EngineStrategy):
@@ -51,7 +53,9 @@ class ThreadLocalEngineStrategy(EngineStrategy):
         u = url.make_url(name_or_url)
         module = u.get_module()
 
-        dialect = module.dialect(**kwargs)
+        args = u.query.copy()
+        args.update(kwargs)
+        dialect = module.dialect(**args)
 
         poolargs = {}
         for key in (('echo_pool', 'echo'), ('pool_size', 'pool_size'), ('max_overflow', 'max_overflow'), ('poolclass', 'poolclass'), ('pool_timeout','timeout'), ('pool', 'pool')):
@@ -63,7 +67,7 @@ class ThreadLocalEngineStrategy(EngineStrategy):
         poolargs['use_threadlocal'] = True
         provider = threadlocal.TLocalConnectionProvider(dialect, u, **poolargs)
 
-        return threadlocal.TLEngine(provider, dialect, **kwargs)
+        return threadlocal.TLEngine(provider, dialect, **args)
 ThreadLocalEngineStrategy()
 
 
index a4297f5db1aa9ce1b086e1bf978d2b2572ed7f38..5e04e1317b3629c7d9deba7b4af3e9e73fd62298 100644 (file)
@@ -3,13 +3,14 @@ import cgi
 import sqlalchemy.exceptions as exceptions
 
 class URL(object):
-    def __init__(self, drivername, username=None, password=None, host=None, port=None, database=None):
+    def __init__(self, drivername, username=None, password=None, host=None, port=None, database=None, query=None):
         self.drivername = drivername
         self.username = username
         self.password = password
         self.host = host
         self.port = port
         self.database= database
+        self.query = query or {}
     def __str__(self):
         s = self.drivername + "://"
         if self.username is not None:
@@ -23,6 +24,10 @@ class URL(object):
             s += ':' + self.port
         if self.database is not None:
             s += '/' + self.database
+        if len(self.query):
+            keys = self.query.keys()
+            keys.sort()
+            s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
         return s
     def get_module(self):
         return getattr(__import__('sqlalchemy.databases.%s' % self.drivername).databases, self.drivername)
@@ -66,7 +71,13 @@ def _parse_rfc1738_args(name):
     m = pattern.match(name)
     if m is not None:
         (name, username, password, host, port, database) = m.group(1, 2, 3, 4, 5, 6)
-        opts = {'username':username,'password':password,'host':host,'port':port,'database':database}
+        if database is not None:
+            tokens = database.split(r"?", 2)
+            database = tokens[0]
+            query = (len(tokens) > 1 and dict( cgi.parse_qsl(tokens[1]) ) or None)
+        else:
+            query = None
+        opts = {'username':username,'password':password,'host':host,'port':port,'database':database, 'query':query}
         return URL(name, **opts)
     else:
         raise exceptions.ArgumentError("Could not parse rfc1738 URL from string '%s'" % name)
index f53b7e3f7cb3f2f62ee4cb565a42a785a03d9936..43389c272cbc782faeea5fc1c8df7a7717629573 100644 (file)
@@ -13,10 +13,12 @@ class ParseConnectTest(PersistTest):
             'dbtype://username:password@127.0.0.1:1521',
             'dbtype://hostspec/database',
             'dbtype://hostspec',
+            'dbtype://hostspec/?arg1=val1&arg2=val2',
             'dbtype:///database',
             'dbtype:///:memory:',
             'dbtype:///foo/bar/im/a/file',
             'dbtype:///E:/work/src/LEM/db/hello.db',
+            'dbtype:///E:/work/src/LEM/db/hello.db?foo=bar&hoho=lala',
             'dbtype://',
             'dbtype://username:password@/db'
         ):