]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
- sqlalchemy.sql.expression.Function is now a public
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Jan 2009 19:45:05 +0000 (19:45 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 2 Jan 2009 19:45:05 +0000 (19:45 +0000)
class.  It can be subclassed to provide user-defined
SQL functions in an imperative style, including
with pre-established behaviors.  The postgis.py
example illustrates one usage of this.

CHANGES
examples/postgis/postgis.py
lib/sqlalchemy/engine/base.py
lib/sqlalchemy/sql/compiler.py
lib/sqlalchemy/sql/expression.py
lib/sqlalchemy/sql/functions.py

diff --git a/CHANGES b/CHANGES
index cdfe4617a038060da7234ff7fb75007a7b34bd6c..e098cbdbc9bb7eaa6a9215d3de42b09835385762 100644 (file)
--- a/CHANGES
+++ b/CHANGES
@@ -186,6 +186,12 @@ CHANGES
 - sql
     - Columns can again contain percent signs within their
       names. [ticket:1256]
+    
+    - sqlalchemy.sql.expression.Function is now a public
+      class.  It can be subclassed to provide user-defined
+      SQL functions in an imperative style, including
+      with pre-established behaviors.  The postgis.py
+      example illustrates one usage of this.
       
     - PickleType now favors == comparison by default,
       if the incoming object (such as a dict) implements
index c463cca26eb07937da95a82972e8570a6945d116..802aa0ea9005f3c25b04013c96e565bb493982b8 100644 (file)
@@ -1,7 +1,7 @@
 """A naive example illustrating techniques to help 
 embed PostGIS functionality.
 
-The techniques here could be used a capable developer 
+The techniques here could be used by a capable developer 
 as the basis for a comprehensive PostGIS SQLAlchemy extension.
 Please note this is an entirely incomplete proof of concept
 only, and PostGIS support is *not* a supported feature 
@@ -40,23 +40,79 @@ from sqlalchemy.orm.properties import ColumnProperty
 from sqlalchemy.types import TypeEngine
 from sqlalchemy.sql import expression
 
+# Python datatypes
+
+class GisElement(object):
+    """Represents a geometry value."""
+
+    @property
+    def wkt(self):
+        return func.AsText(literal(self, Geometry))
+
+    @property
+    def wkb(self):
+        return func.AsBinary(literal(self, Geometry))
+
+    def __str__(self):
+        return self.desc
+
+    def __repr__(self):
+        return "<%s at 0x%x; %r>" % (self.__class__.__name__, id(self), self.desc)
+
+class PersistentGisElement(GisElement):
+    """Represents a Geometry value as loaded from the database."""
+    
+    def __init__(self, desc):
+        self.desc = desc
+
+class TextualGisElement(GisElement, expression.Function):
+    """Represents a Geometry value as expressed within application code; i.e. in wkt format.
+    
+    Extends expression.Function so that the value is interpreted as 
+    GeomFromText(value) in a SQL expression context.
+    
+    """
+    
+    def __init__(self, desc, srid=-1):
+        assert isinstance(desc, basestring)
+        self.desc = desc
+        expression.Function.__init__(self, "GeomFromText", desc, srid)
+
+
+# SQL datatypes.
+
 class Geometry(TypeEngine):
-    """Base PostGIS Geometry column type"""
+    """Base PostGIS Geometry column type.
+    
+    Converts bind/result values to/from a PersistentGisElement.
+    
+    """
     
     name = 'GEOMETRY'
     
-    def __init__(self, dimension, srid=-1):
+    def __init__(self, dimension=None, srid=-1):
         self.dimension = dimension
         self.srid = srid
-
+    
+    def bind_processor(self, dialect):
+        def process(value):
+            if value is not None:
+                return value.desc
+            else:
+                return value
+        return process
+        
     def result_processor(self, dialect):
         def process(value):
             if value is not None:
-                return gis_element(value)
+                return PersistentGisElement(value)
             else:
                 return value
         return process
 
+# other datatypes can be added as needed, which 
+# currently only affect DDL statements.
+
 class Point(Geometry):
     name = 'POINT'
     
@@ -66,10 +122,25 @@ class Curve(Geometry):
 class LineString(Curve):
     name = 'LINESTRING'
 
-# ... add other types as needed
+# ... etc.
+
 
+# DDL integration
 
 class GISDDL(object):
+    """A DDL extension which integrates SQLAlchemy table create/drop 
+    methods with PostGis' AddGeometryColumn/DropGeometryColumn functions.
+    
+    Usage::
+    
+        sometable = Table('sometable', metadata, ...)
+        
+        GISDDL(sometable)
+
+        sometable.create()
+    
+    """
+    
     def __init__(self, table):
         for event in ('before-create', 'after-create', 'before-drop', 'after-drop'):
             table.ddl_listeners[event].append(self)
@@ -95,23 +166,25 @@ class GISDDL(object):
         elif event == 'after-drop':
             table._columns = self._stack.pop()
 
+# ORM integration
+
 def _to_postgis(value):
+    """Interpret a value as a GIS-compatible construct."""
+    
     if hasattr(value, '__clause_element__'):
         return value.__clause_element__()
-    elif isinstance(value, expression.ClauseElement):
+    elif isinstance(value, (expression.ClauseElement, GisElement)):
         return value
     elif isinstance(value, basestring):
-        return func.GeomFromText(value, -1)
-    elif isinstance(value, gis_element):
-        return value.desc
+        return TextualGisElement(value)
     elif value is None:
         return None
     else:
         raise Exception("Invalid type")
-        
+
 
 class GisAttribute(AttributeExtension):
-    """Intercepts 'set' events on a mapped instance and 
+    """Intercepts 'set' events on a mapped instance attribute and 
     converts the incoming value to a GIS expression.
     
     """
@@ -123,44 +196,36 @@ class GisComparator(ColumnProperty.ColumnComparator):
     """Intercepts standard Column operators on mapped class attributes
     and overrides their behavior.
     
-    
     """
     
+    # override the __eq__() operator
     def __eq__(self, other):
         return self.__clause_element__().op('~=')(_to_postgis(other))
 
+    # add a custom operator
     def intersects(self, other):
         return self.__clause_element__().op('&&')(_to_postgis(other))
-    
-class gis_element(object):
-    """Represents a geometry value.
-    
-    This is just the raw string returned by PostGIS, 
-    plus some helper functions.
-    
-    """
-    
-    def __init__(self, desc):
-        self.desc = desc
-    
-    @property
-    def wkt(self):
-        return func.AsText(self.desc)
-
-    @property
-    def wkb(self):
-        return func.AsBinary(self.desc)
-
         
+    # any number of GIS operators can be overridden/added here
+    # using the techniques above.
+        
+
 def GISColumn(*args, **kw):
-    """Define a declarative column property with GIS behavior."""
+    """Define a declarative column property with GIS behavior.
     
+    This just produces orm.column_property() with the appropriate
+    extension and comparator_factory arguments.  The given arguments
+    are passed through to Column.  The declarative module extracts
+    the Column for inclusion in the mapped table.
+    
+    """
     return column_property(
                 Column(*args, **kw), 
                 extension=GisAttribute(), 
                 comparator_factory=GisComparator
             )
-    
+
+# illustrate usage
 if __name__ == '__main__':
     from sqlalchemy import *
     from sqlalchemy.orm import *
@@ -187,8 +252,7 @@ if __name__ == '__main__':
 
     session = sessionmaker(bind=engine)()
     
-    # Add objects using strings for the geometry objects; the attribute extension
-    # converts them to GeomFromText
+    # Add objects.  We can use strings...
     session.add_all([
         Road(road_name='Jeff Rd', road_geom='LINESTRING(191232 243118,191108 243242)'),
         Road(road_name='Geordie Rd', road_geom='LINESTRING(189141 244158,189265 244817)'),
@@ -197,18 +261,29 @@ if __name__ == '__main__':
         Road(road_name='Phil Tce', road_geom='LINESTRING(190131 224148,190871 228134)'),
     ])
     
-    # GeomFromText can be called directly here as well.
-    session.add(
-        Road(road_name='Dave Cres', road_geom=func.GeomFromText('LINESTRING(198231 263418,198213 268322)', -1)),
-    )
+    # or use an explicit TextualGisElement (similar to saying func.GeomFromText())
+    r = Road(road_name='Dave Cres', road_geom=TextualGisElement('LINESTRING(198231 263418,198213 268322)', -1))
+    session.add(r)
+    
+    # pre flush, the TextualGisElement represents the string we sent.
+    assert str(r.road_geom) == 'LINESTRING(198231 263418,198213 268322)'
+    assert session.scalar(r.road_geom.wkt) == 'LINESTRING(198231 263418,198213 268322)'
     
     session.commit()
+
+    # after flush and/or commit, all the TextualGisElements become PersistentGisElements.
+    assert str(r.road_geom) == "01020000000200000000000000B832084100000000E813104100000000283208410000000088601041"
     
     r1 = session.query(Road).filter(Road.road_name=='Graeme Ave').one()
-
-    # illustrate the overridden __eq__() operator
+    
+    # illustrate the overridden __eq__() operator.
+    
+    # strings come in as TextualGisElements
     r2 = session.query(Road).filter(Road.road_geom == 'LINESTRING(189412 252431,189631 259122)').one()
+    
+    # PersistentGisElements work directly
     r3 = session.query(Road).filter(Road.road_geom == r1.road_geom).one()
+    
     assert r1 is r2 is r3
 
     # illustrate the "intersects" operator
index c547c0e54dd65bd7d22c79d60a5e3d8d66e47a46..de6346b8b09356c50931dfc5be3c77423f1859a0 100644 (file)
@@ -962,7 +962,7 @@ class Connection(Connectable):
 
     # poor man's multimethod/generic function thingy
     executors = {
-        expression._Function: _execute_function,
+        expression.Function: _execute_function,
         expression.ClauseElement: _execute_clauseelement,
         Compiled: _execute_compiled,
         schema.SchemaItem: _execute_default,
index 31fc9ae1e60a5c4773d335979c60881101a592c1..0430f053bd35f0613cdd71225d77d0aafd253ccb 100644 (file)
@@ -463,7 +463,7 @@ class DefaultCompiler(engine.Compiled):
             not isinstance(column.table, sql.Select):
             return _CompileLabel(column, sql._generated_label(column.name))
         elif not isinstance(column, (sql._UnaryExpression, sql._TextClause, sql._BindParamClause)) \
-                and (not hasattr(column, 'name') or isinstance(column, sql._Function)):
+                and (not hasattr(column, 'name') or isinstance(column, sql.Function)):
             return _CompileLabel(column, column.anon_label)
         else:
             return column
index 07df207dd971d1b4eb74fd910a115e4e553f1779..7204e29564b7d991a7641a7a0d6647b1db7ba171 100644 (file)
@@ -820,12 +820,12 @@ def text(text, bind=None, *args, **kwargs):
     return _TextClause(text, bind=bind, *args, **kwargs)
 
 def null():
-    """Return a ``_Null`` object, which compiles to ``NULL`` in a sql statement."""
+    """Return a :class:`_Null` object, which compiles to ``NULL`` in a sql statement."""
 
     return _Null()
 
 class _FunctionGenerator(object):
-    """Generate ``_Function`` objects based on getattr calls."""
+    """Generate :class:`Function` objects based on getattr calls."""
 
     def __init__(self, **opts):
         self.__names = []
@@ -856,7 +856,7 @@ class _FunctionGenerator(object):
             if func is not None:
                 return func(*c, **o)
 
-        return _Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
+        return Function(self.__names[-1], packagenames=self.__names[0:-1], *c, **o)
 
 # "func" global - i.e. func.count()
 func = _FunctionGenerator()
@@ -2228,7 +2228,7 @@ class _CalculatedClause(ColumnElement):
     def _compare_type(self, obj):
         return self.type
 
-class _Function(_CalculatedClause, FromClause):
+class Function(_CalculatedClause, FromClause):
     """Describe a SQL function.
 
     Extends ``_CalculatedClause``, turn the *clauselist* into function
index b57b242f521e4edd267da52d0f20108d609118c9..1bcc6d864f9b03bad9806a0d0330020e3fb0b838 100644 (file)
@@ -1,6 +1,6 @@
 from sqlalchemy import types as sqltypes
 from sqlalchemy.sql.expression import (
-    ClauseList, _Function, _literal_as_binds, text
+    ClauseList, Function, _literal_as_binds, text
     )
 from sqlalchemy.sql import operators
 from sqlalchemy.sql.visitors import VisitableType
@@ -10,7 +10,7 @@ class _GenericMeta(VisitableType):
         args = [_literal_as_binds(c) for c in args]
         return type.__call__(self, *args, **kwargs)
 
-class GenericFunction(_Function):
+class GenericFunction(Function):
     __metaclass__ = _GenericMeta
 
     def __init__(self, type_=None, group=True, args=(), **kwargs):