]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Cleaned up the create_connect_args so that it makes no expectations about keys. Fixes...
authorMichael Trier <mtrier@gmail.com>
Sun, 19 Oct 2008 01:18:15 +0000 (01:18 +0000)
committerMichael Trier <mtrier@gmail.com>
Sun, 19 Oct 2008 01:18:15 +0000 (01:18 +0000)
lib/sqlalchemy/databases/mssql.py
test/dialect/mssql.py

index 42743870a0a42e0b48d6c264aea407992bde4cb4..1ff482cf56c5d1b10ed28ae867a29b2e6c350956 100644 (file)
@@ -443,6 +443,26 @@ class MSSQLDialect(default.DefaultDialect):
                 raise ImportError('No DBAPI module detected for MSSQL - please install pyodbc, pymssql, or adodbapi')
     dbapi = classmethod(dbapi)
 
+    def server_version_info(self, connection):
+        """A tuple of the database server version.
+
+        Formats the remote server version as a tuple of version values,
+        e.g. ``(9, 0, 1399)``.  If there are strings in the version number
+        they will be in the tuple too, so don't count on these all being
+        ``int`` values.
+
+        This is a fast check that does not require a round trip.  It is also
+        cached per-Connection.
+        """
+        return connection.dialect._server_version_info(connection.connection)
+    server_version_info = base.connection_memoize(
+        ('mssql', 'server_version_info'))(server_version_info)
+
+    def _server_version_info(self, dbapi_con):
+        """Return a tuple of the database's version number."""
+
+        raise NotImplementedError()
+    
     def create_connect_args(self, url):
         opts = url.translate_connect_args(username='user')
         opts.update(url.query)
@@ -772,18 +792,18 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
         if 'max_identifier_length' in keys:
             self.max_identifier_length = int(keys.pop('max_identifier_length'))
         if 'dsn' in keys:
-            connectors = ['dsn=%s' % keys['dsn']]
+            connectors = ['dsn=%s' % keys.pop('dsn')]
         else:
             connectors = ["DRIVER={%s}" % keys.pop('driver', 'SQL Server'),
-                          'Server=%s' % keys['host'],
-                          'Database=%s' % keys['database'] ]
+                          'Server=%s' % keys.pop('host', ''),
+                          'Database=%s' % keys.pop('database', '') ]
             if 'port' in keys:
-                connectors.append('Port=%d' % int(keys['port']))
+                connectors.append('Port=%d' % int(keys.pop('port')))
         
-        user = keys.get("user")
+        user = keys.pop("user", None)
         if user:
             connectors.append("UID=%s" % user)
-            connectors.append("PWD=%s" % keys.get("password", ""))
+            connectors.append("PWD=%s" % keys.pop('password', ''))
         else:
             connectors.append("TrustedConnection=Yes")
 
@@ -791,7 +811,7 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
         # textual data from your database encoding to your client encoding 
         # This should obviously be set to 'No' if you query a cp1253 encoded 
         # database from a latin1 client... 
-        if 'odbc_autotranslate' in keys: 
+        if 'odbc_autotranslate' in keys:
             connectors.append("AutoTranslate=%s" % keys.pop("odbc_autotranslate"))
 
         # Allow specification of partial ODBC connect string
@@ -800,7 +820,7 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
             if odbc_options[0]=="'" and odbc_options[-1]=="'":
                 odbc_options=odbc_options[1:-1]
             connectors.append(odbc_options)
-        
+        connectors.extend(['%s=%s' % (k,v) for k,v in keys.iteritems()])
         return [[";".join (connectors)], {}]
 
     def is_disconnect(self, e):
@@ -828,6 +848,18 @@ class MSSQLDialect_pyodbc(MSSQLDialect):
                     cursor.nextset()
             context._last_inserted_ids = [int(row[0])]
 
+    def _server_version_info(self, dbapi_con):
+        """Convert a pyodbc SQL_DBMS_VER string into a tuple."""
+
+        version = []
+        r = re.compile('[.\-]')
+        for n in r.split(dbapi_con.getinfo(self.dbapi.SQL_DBMS_VER)):
+            try:
+                version.append(int(n))
+            except ValueError:
+                version.append(n)
+        return tuple(version)
+
 class MSSQLDialect_adodbapi(MSSQLDialect):
     supports_sane_rowcount = True
     supports_sane_multi_rowcount = True
index 02c583d5dfe64680d1a96e00918fa912e02b1bd6..4708cc28c4cfcc93bc2f0775d605adb00f848f18 100755 (executable)
@@ -5,6 +5,7 @@ from sqlalchemy.orm import *
 from sqlalchemy import exc
 from sqlalchemy.sql import table, column
 from sqlalchemy.databases import mssql
+import sqlalchemy.engine.url as url
 from testlib import *
 
 
@@ -362,5 +363,20 @@ class MatchTest(TestBase, AssertsCompiledSQL):
         self.assertEquals([1, 3, 5], [r.id for r in results])
 
 
+class ParseConnectTest(TestBase, AssertsCompiledSQL):
+    __only_on__ = 'mssql'
+
+    def test_pyodbc_connect(self):
+        u = url.make_url('mssql://username:password@hostspec/database')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password'], {}], connection)
+
+    def test_pyodbc_extra_connect(self):
+        u = url.make_url('mssql://username:password@hostspec/database?LANGUAGE=us_english&foo=bar')
+        dialect = mssql.MSSQLDialect_pyodbc()
+        connection = dialect.create_connect_args(u)
+        self.assertEquals([['DRIVER={SQL Server};Server=hostspec;Database=database;UID=username;PWD=password;foo=bar;LANGUAGE=us_english'], {}], connection)
+
 if __name__ == "__main__":
     testenv.main()