]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- column labels are now generated in the compilation phase, which
authorMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Mar 2007 19:24:27 +0000 (19:24 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Sat, 24 Mar 2007 19:24:27 +0000 (19:24 +0000)
means their lengths are dialect-dependent.  So on oracle a label
that gets truncated to 30 chars will go out to 63 characters
on postgres.  Also, the true labelname is always attached as the
accessor on the parent Selectable so theres no need to be aware
of the genrerated label names [ticket:512].
- ResultProxy column targeting is greatly simplified, and relies
upon the ANSICompiler's column_labels map to translate the built-in
label on a _ColumnClause (which is now considered to be a unique
identifier of that column) to the label which was generated at compile
time.
- still need to put a baseline of ColumnClause targeting for
ResultProxy objects that originated from a textual query.

CHANGES
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/oracle.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/engine/default.py
lib/sqlalchemy/sql.py
test/sql/alltests.py
test/sql/labels.py [new file with mode: 0644]

diff --git a/CHANGES b/CHANGES
index 02a361d4b4e602c064e75940ca916a37622d32bf..71843f36000275b27c3e5b3c9f2936d79400cdc5 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -1,3 +1,12 @@
+0.3.7
+- sql:
+    - column labels are now generated in the compilation phase, which
+      means their lengths are dialect-dependent.  So on oracle a label
+      that gets truncated to 30 chars will go out to 63 characters
+      on postgres.  Also, the true labelname is always attached as the
+      accessor on the parent Selectable so theres no need to be aware
+      of the genrerated label names [ticket:512].
+      
 0.3.6
 - sql:
     - bindparam() names are now repeatable!  specify two
index ebaedca542d2b97adb37fd63e52b3dc572fce589..0d4fba4e8a1716a1415e3e7ec92552c6251d30c9 100644 (file)
@@ -12,7 +12,7 @@ module.
 
 from sqlalchemy import schema, sql, engine, util, sql_util, exceptions
 from  sqlalchemy.engine import default
-import string, re, sets, weakref
+import string, re, sets, weakref, random
 
 ANSI_FUNCS = sets.ImmutableSet(['CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP',
                                 'CURRENT_USER', 'LOCALTIME', 'LOCALTIMESTAMP',
@@ -125,8 +125,8 @@ class ANSICompiler(sql.Compiled):
         # which will be passed to a ResultProxy and used for resultset-level value conversion
         self.typemap = {}
 
-        # a dictionary of select columns mapped to their name or key
-        self.columns = {}
+        # a dictionary of select columns labels mapped to their "generated" label
+        self.column_labels = {}
 
         # True if this compiled represents an INSERT
         self.isinsert = False
@@ -237,16 +237,22 @@ class ANSICompiler(sql.Compiled):
         return ""
 
     def visit_label(self, label):
+        labelname = label.name
+        if len(labelname) >= self.dialect.max_identifier_length():
+            labelname = labelname[0:self.dialect.max_identifier_length() - 6] + "_" + hex(random.randint(0, 65535))[2:]
+        
         if len(self.select_stack):
-            self.typemap.setdefault(label.name.lower(), label.obj.type)
-        self.strings[label] = self.strings[label.obj] + " AS "  + self.preparer.format_label(label)
-
+            self.typemap.setdefault(labelname.lower(), label.obj.type)
+            if isinstance(label.obj, sql._ColumnClause):
+                self.column_labels[label.obj._label] = labelname.lower()
+        self.strings[label] = self.strings[label.obj] + " AS "  + self.preparer.format_label(label, labelname)
+        
     def visit_column(self, column):
         if len(self.select_stack):
             # if we are within a visit to a Select, set up the "typemap"
             # for this column which is used to translate result set values
             self.typemap.setdefault(column.name.lower(), column.type)
-            self.columns.setdefault(column.key, column)
+            self.column_labels.setdefault(column._label, column.name.lower())
         if column.table is None or not column.table.named_with_column():
             self.strings[column] = self.preparer.format_column(column)
         else:
@@ -1015,8 +1021,8 @@ class ANSIIdentifierPreparer(object):
     def format_sequence(self, sequence):
         return self.__generic_obj_format(sequence, sequence.name)
 
-    def format_label(self, label):
-        return self.__generic_obj_format(label, label.name)
+    def format_label(self, label, name=None):
+        return self.__generic_obj_format(label, name or label.name)
 
     def format_alias(self, alias):
         return self.__generic_obj_format(alias, alias.name)
index d9b85746d67b25547f160c1ce646f2d13c642adb..6141c943f0ae87c2a8216a298424c9d7809b2030 100644 (file)
@@ -192,6 +192,9 @@ class OracleDialect(ansisql.ANSIDialect):
     def type_descriptor(self, typeobj):
         return sqltypes.adapt_type(typeobj, colspecs)
 
+    def max_identifier_length(self):
+        return 30
+        
     def oid_column_name(self, column):
         if not isinstance(column.table, sql.TableClause) and not isinstance(column.table, sql.Select):
             return None
index 93f20889c57678a14d81131981d3b7885e8da177..43d570070f4d215c6b75cfad1c7dc703f4d1ed76 100644 (file)
@@ -279,6 +279,9 @@ class PGDialect(ansisql.ANSIDialect):
     def create_execution_context(self):
         return PGExecutionContext(self)
 
+    def max_identifier_length(self):
+        return 68
+        
     def type_descriptor(self, typeobj):
         if self.version == 2:
             return sqltypes.adapt_type(typeobj, pg2_colspecs)
index 4c9595437888500e2f63634b3c6f7a6bcb2d236b..7f7bde81bb2025a55cda04157f4958004472c210 100644 (file)
@@ -105,6 +105,12 @@ class Dialect(sql.AbstractDialect):
 
         raise NotImplementedError()
 
+    def max_identifier_length(self):
+        """Return the maximum length of identifier names.
+        
+        Return None if no limit."""
+        return None
+
     def supports_sane_rowcount(self):
         """Indicate whether the dialect properly implements statements rowcount.
 
@@ -503,7 +509,7 @@ class Connection(Connectable):
         proxy(str(compiled), parameters)
         context.post_exec(self.__engine, proxy, compiled, parameters)
         rpargs = self.__engine.dialect.create_result_proxy_args(self, cursor)
-        return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, columns=compiled.columns, **rpargs)
+        return ResultProxy(self.__engine, self, cursor, context, typemap=compiled.typemap, column_labels=compiled.column_labels, **rpargs)
 
     # poor man's multimethod/generic function thingy
     executors = {
@@ -803,7 +809,7 @@ class ResultProxy(object):
         else:
             return object.__new__(cls, *args, **kwargs)
 
-    def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None, columns=None, should_prefetch=None):
+    def __init__(self, engine, connection, cursor, executioncontext=None, typemap=None, column_labels=None, should_prefetch=None):
         """ResultProxy objects are constructed via the execute() method on SQLEngine."""
 
         self.connection = connection
@@ -811,7 +817,7 @@ class ResultProxy(object):
         self.cursor = cursor
         self.engine = engine
         self.closed = False
-        self.columns = columns
+        self.column_labels = column_labels
         if executioncontext is not None:
             self.__executioncontext = executioncontext
             self.rowcount = executioncontext.get_rowcount(cursor)
@@ -823,6 +829,7 @@ class ResultProxy(object):
         self.props = {}
         self.keys = []
         i = 0
+        
         if metadata is not None:
             for item in metadata:
                 # sqlite possibly prepending table name to colnames so strip
@@ -874,36 +881,21 @@ class ResultProxy(object):
         try:
             return self.__key_cache[key]
         except KeyError:
-            # TODO: use has_key on these, too many potential KeyErrors being raised
-            if isinstance(key, sql.ColumnElement):
-                try:
-                    rec = self.props[key._label.lower()]
-                except KeyError:
-                    try:
-                        rec = self.props[key.key.lower()]
-                    except KeyError:
-                        try:
-                            rec = self.props[key.name.lower()]
-                        except KeyError:
-                            raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % str(key))
-            elif isinstance(key, str):
-                try:
-                    rec = self.props[key.lower()]
-                except KeyError:
-                    try:
-                        if self.columns is not None:
-                            rec = self._convert_key(self.columns[key])
-                        else:
-                            raise
-                    except KeyError:
-                        raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % str(key))
-            else:
-                try:
-                    rec = self.props[key]
-                except KeyError:
-                    raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % str(key))
+            if isinstance(key, int) and key in self.props:
+                rec = self.props[key]
+            elif isinstance(key, basestring) and key.lower() in self.props:
+                rec = self.props[key.lower()]
+            elif isinstance(key, sql.ColumnElement):
+                label = self.column_labels.get(key._label, key.name)
+                if label in self.props:
+                    rec = self.props[label]
+                        
+            if not "rec" in locals():
+                raise exceptions.NoSuchColumnError("Could not locate column in row for column '%s'" % (repr(key)))
+
             self.__key_cache[key] = rec
             return rec
+            
 
     def _has_key(self, row, key):
         try:
index c6e0d9dc4eed7154a6b7ac040485248d878c5f67..798d02d32b7fee44b97a6d049b213e30b5b1bff3 100644 (file)
@@ -48,6 +48,11 @@ class DefaultDialect(base.Dialect):
             typeobj = typeobj()
         return typeobj
 
+    def max_identifier_length(self):
+        # TODO: probably raise this and fill out
+        # db modules better
+        return 30
+        
     def oid_column_name(self, column):
         return None
 
index 8059d95151962aa5136040ec15fa0c5096751c9d..bd018e89cdfc1607d4cb677ea122327a87615236 100644 (file)
@@ -1771,7 +1771,7 @@ class Join(FromClause):
         return [self] + self.onclause._get_from_objects() + self.left._get_from_objects() + self.right._get_from_objects()
 
 class Alias(FromClause):
-    def __init__(self, selectable, alias = None):
+    def __init__(self, selectable, alias=None):
         baseselectable = selectable
         while isinstance(baseselectable, Alias):
             baseselectable = baseselectable.selectable
@@ -1808,6 +1808,7 @@ class Alias(FromClause):
         for c in self.c:
             yield c
         yield self.selectable
+        
     def accept_visitor(self, visitor):
         visitor.visit_alias(self)
 
@@ -1865,6 +1866,13 @@ class _ColumnClause(ColumnElement):
         self.is_literal = is_literal
 
     def _get_label(self):
+        """generate a 'label' for this column.
+        
+        the label is a product of the parent table name and column name, and 
+        is treated as a unique identifier of this Column across all Tables and derived 
+        selectables for a particular metadata collection.
+        """
+        
         # for a "literal" column, we've no idea what the text is
         # therefore no 'label' can be automatically generated
         if self.is_literal:
@@ -1872,8 +1880,10 @@ class _ColumnClause(ColumnElement):
         if self.__label is None:
             if self.table is not None and self.table.named_with_column():
                 self.__label = self.table.name + "_" + self.name
-                if self.table.c.has_key(self.__label) or len(self.__label) >= 30:
-                    self.__label = self.__label[0:24] + "_" + hex(random.randint(0, 65535))[2:]
+                counter = 1
+                while self.table.c.has_key(self.__label):
+                    self.__label = self.__label + "_%d" % counter
+                    counter += 1
             else:
                 self.__label = self.name
             self.__label = "".join([x for x in self.__label if x in legal_characters])
index 2517cdf8d2041a124145c9de3e10584928cde0fb..9f1c0d36eb81637a5987f6a3ec5af8d9eb70a8e5 100644 (file)
@@ -11,6 +11,7 @@ def suite():
         'sql.select',
         'sql.selectable',
         'sql.case_statement', 
+        'sql.labels',
         
         # assorted round-trip tests
         'sql.query',
diff --git a/test/sql/labels.py b/test/sql/labels.py
new file mode 100644 (file)
index 0000000..0b39576
--- /dev/null
@@ -0,0 +1,34 @@
+import testbase
+
+from sqlalchemy import *
+
+class LongLabelsTest(testbase.PersistTest):
+    def setUpAll(self):
+        global metadata, table1
+        metadata = MetaData(engine=testbase.db)
+        table1 = Table("some_large_named_table", metadata,
+            Column("this_is_the_primary_key_column", Integer, primary_key=True),
+            Column("this_is_the_data_column", String(30))
+            )
+        metadata.create_all()
+        table1.insert().execute(**{"this_is_the_primary_key_column":1, "this_is_the_data_column":"data1"})
+        table1.insert().execute(**{"this_is_the_primary_key_column":2, "this_is_the_data_column":"data2"})
+        table1.insert().execute(**{"this_is_the_primary_key_column":3, "this_is_the_data_column":"data3"})
+        table1.insert().execute(**{"this_is_the_primary_key_column":4, "this_is_the_data_column":"data4"})
+    def tearDownAll(self):
+        metadata.drop_all()
+        
+    def test_result(self):
+        r = table1.select(use_labels=True).execute()
+        result = []
+        for row in r:
+            result.append((row[table1.c.this_is_the_primary_key_column], row[table1.c.this_is_the_data_column]))
+        assert result == [
+            (1, "data1"),
+            (2, "data2"),
+            (3, "data3"),
+            (4, "data4"),
+        ]
+    
+if __name__ == '__main__':
+    testbase.main()
\ No newline at end of file