]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Implemented FBDialect.server_version_info()
authorLele Gaifax <lele@metapensiero.it>
Sat, 15 Dec 2007 09:02:41 +0000 (09:02 +0000)
committerLele Gaifax <lele@metapensiero.it>
Sat, 15 Dec 2007 09:02:41 +0000 (09:02 +0000)
lib/sqlalchemy/databases/firebird.py
test/dialect/firebird.py

index 11b30f72f494d220e2ffcdddb1a760efd7e757a2..55e3ffc6af858af9e32de19fae24d58beb6796ec 100644 (file)
@@ -43,16 +43,22 @@ By default this module is biased toward dialect 3, but you can easily
 tweak it to handle dialect 1 if needed::
 
   from sqlalchemy import types as sqltypes
-  from sqlalchemy.databases.firebird import FBCompiler, FBDate, colspecs, ischema_names
-
-  # Change the name of the function ``length`` to use the UDF version
-  # instead of ``char_length``
-  FBCompiler.LENGTH_FUNCTION_NAME = 'strlen'
+  from sqlalchemy.databases.firebird import FBDate, colspecs, ischema_names
 
   # Adjust the mapping of the timestamp kind
   ischema_names['TIMESTAMP'] = FBDate
   colspecs[sqltypes.DateTime] = FBDate,
 
+Other aspects may be version-specific. You can use the ``server_version_info()`` method
+on the ``FBDialect`` class to do whatever is needed::
+
+  from sqlalchemy.databases.firebird import FBCompiler
+
+  if engine.dialect.server_version_info(connection) < (2,0):
+      # Change the name of the function ``length`` to use the UDF version
+      # instead of ``char_length``
+      FBCompiler.LENGTH_FUNCTION_NAME = 'strlen'
+
 
 .. [#] Well, that is not the whole story, as the client may still ask
        a different (lower) dialect...
@@ -276,6 +282,29 @@ class FBDialect(default.DefaultDialect):
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
+    def server_version_info(self, connection):
+        """Get the version of the Firebird server used by a connection.
+
+        Returns a tuple of (`major`, `minor`, `build`), three integers
+        representing the version of the attached server.
+        """
+
+        # This is the simpler approach (the other uses the services api),
+        # that for backward compatibility reasons returns a string like
+        #   LI-V6.3.3.12981 Firebird 2.0
+        # where the first version is a fake one resembling the old
+        # Interbase signature. This is more than enough for our purposes,
+        # as this is mainly (only?) used by the testsuite.
+
+        from re import match
+
+        fbconn = connection.connection.connection
+        version = fbconn.server_version
+        m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version)
+        if not m:
+            raise exceptions.AssertionError("Could not determine version from string '%s'" % version)
+        return tuple([int(x) for x in m.group(5, 6, 4)])
+
     def _normalize_name(self, name):
         """Convert the name to lowercase if it is possible"""
 
index f14422eb02ea6114fcf76627b3671a8701a665f5..98f0e9e9a5ca8e18bbcfe4ec4e651c3c50f49db1 100644 (file)
@@ -73,14 +73,17 @@ class CompileTest(SQLCompileTest):
         self.assert_compile(select([func.max(t.c.col1)]), "SELECT max(sometable.col1) FROM sometable")
 
 
-class StrLenTest(PersistTest):
-    # On FB the length() function is implemented by an external UDF,
-    # strlen().  Various SA tests fail because they pass a parameter
-    # to it, and that does not work (it always results the maximum
-    # string length the UDF was declared to accept).
-    # This test checks that at least it works ok in other cases.
+class MiscFBTests(PersistTest):
 
+    __only_on__ = 'firebird'
+    
     def test_strlen(self):
+        # On FB the length() function is implemented by an external
+        # UDF, strlen().  Various SA tests fail because they pass a
+        # parameter to it, and that does not work (it always results
+        # the maximum string length the UDF was declared to accept).
+        # This test checks that at least it works ok in other cases.
+        
         meta = MetaData(testbase.db)
         t = Table('t1', meta,
             Column('id', Integer, Sequence('t1idseq'), primary_key=True),
@@ -94,6 +97,9 @@ class StrLenTest(PersistTest):
         finally:
             meta.drop_all()
 
+    def test_server_version_info(self):
+        version = testbase.db.dialect.server_version_info(testbase.db.connect())
+        assert len(version) == 3, "Got strange version info: %s" % repr(version)
 
 if __name__ == '__main__':
     testbase.main()