From 50ce0006a5d8a12297465643d3a052081e35178f Mon Sep 17 00:00:00 2001 From: Lele Gaifax Date: Sat, 15 Dec 2007 09:02:41 +0000 Subject: [PATCH] Implemented FBDialect.server_version_info() --- lib/sqlalchemy/databases/firebird.py | 39 ++++++++++++++++++++++++---- test/dialect/firebird.py | 18 ++++++++----- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/lib/sqlalchemy/databases/firebird.py b/lib/sqlalchemy/databases/firebird.py index 11b30f72f4..55e3ffc6af 100644 --- a/lib/sqlalchemy/databases/firebird.py +++ b/lib/sqlalchemy/databases/firebird.py @@ -43,16 +43,22 @@ By default this module is biased toward dialect 3, but you can easily tweak it to handle dialect 1 if needed:: from sqlalchemy import types as sqltypes - from sqlalchemy.databases.firebird import FBCompiler, FBDate, colspecs, ischema_names - - # Change the name of the function ``length`` to use the UDF version - # instead of ``char_length`` - FBCompiler.LENGTH_FUNCTION_NAME = 'strlen' + from sqlalchemy.databases.firebird import FBDate, colspecs, ischema_names # Adjust the mapping of the timestamp kind ischema_names['TIMESTAMP'] = FBDate colspecs[sqltypes.DateTime] = FBDate, +Other aspects may be version-specific. You can use the ``server_version_info()`` method +on the ``FBDialect`` class to do whatever is needed:: + + from sqlalchemy.databases.firebird import FBCompiler + + if engine.dialect.server_version_info(connection) < (2,0): + # Change the name of the function ``length`` to use the UDF version + # instead of ``char_length`` + FBCompiler.LENGTH_FUNCTION_NAME = 'strlen' + .. [#] Well, that is not the whole story, as the client may still ask a different (lower) dialect... @@ -276,6 +282,29 @@ class FBDialect(default.DefaultDialect): def type_descriptor(self, typeobj): return sqltypes.adapt_type(typeobj, colspecs) + def server_version_info(self, connection): + """Get the version of the Firebird server used by a connection. + + Returns a tuple of (`major`, `minor`, `build`), three integers + representing the version of the attached server. + """ + + # This is the simpler approach (the other uses the services api), + # that for backward compatibility reasons returns a string like + # LI-V6.3.3.12981 Firebird 2.0 + # where the first version is a fake one resembling the old + # Interbase signature. This is more than enough for our purposes, + # as this is mainly (only?) used by the testsuite. + + from re import match + + fbconn = connection.connection.connection + version = fbconn.server_version + m = match('\w+-V(\d+)\.(\d+)\.(\d+)\.(\d+) \w+ (\d+)\.(\d+)', version) + if not m: + raise exceptions.AssertionError("Could not determine version from string '%s'" % version) + return tuple([int(x) for x in m.group(5, 6, 4)]) + def _normalize_name(self, name): """Convert the name to lowercase if it is possible""" diff --git a/test/dialect/firebird.py b/test/dialect/firebird.py index f14422eb02..98f0e9e9a5 100644 --- a/test/dialect/firebird.py +++ b/test/dialect/firebird.py @@ -73,14 +73,17 @@ class CompileTest(SQLCompileTest): self.assert_compile(select([func.max(t.c.col1)]), "SELECT max(sometable.col1) FROM sometable") -class StrLenTest(PersistTest): - # On FB the length() function is implemented by an external UDF, - # strlen(). Various SA tests fail because they pass a parameter - # to it, and that does not work (it always results the maximum - # string length the UDF was declared to accept). - # This test checks that at least it works ok in other cases. +class MiscFBTests(PersistTest): + __only_on__ = 'firebird' + def test_strlen(self): + # On FB the length() function is implemented by an external + # UDF, strlen(). Various SA tests fail because they pass a + # parameter to it, and that does not work (it always results + # the maximum string length the UDF was declared to accept). + # This test checks that at least it works ok in other cases. + meta = MetaData(testbase.db) t = Table('t1', meta, Column('id', Integer, Sequence('t1idseq'), primary_key=True), @@ -94,6 +97,9 @@ class StrLenTest(PersistTest): finally: meta.drop_all() + def test_server_version_info(self): + version = testbase.db.dialect.server_version_info(testbase.db.connect()) + assert len(version) == 3, "Got strange version info: %s" % repr(version) if __name__ == '__main__': testbase.main() -- 2.47.3