postgres kickin my ass w00p
author     Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Oct 2005 03:43:22 +0000 (03:43 +0000)
committer  Mike Bayer <mike_mp@zzzcomputing.com>
Fri, 21 Oct 2005 03:43:22 +0000 (03:43 +0000)
lib/sqlalchemy/ansisql.py
lib/sqlalchemy/databases/postgres.py
lib/sqlalchemy/databases/sqlite.py
lib/sqlalchemy/mapper.py
lib/sqlalchemy/schema.py
lib/sqlalchemy/sql.py
test/mapper.py
test/objectstore.py
test/tables.py

diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
index 95aa47cde42b33e0209ddb8b33c8893538952fce..4f90d485c5d6f4196936b7fa5b70a6856e5f3cf8 100644
--- a/lib/sqlalchemy/ansisql.py
+++ b/lib/sqlalchemy/ansisql.py
@@ -138,8 +138,11 @@ class ANSICompiler(sql.Compiled):
         while self.binds.setdefault(key, bindparam) is not bindparam:
             key = "%s_%d" % (bindparam.key, count)
             count += 1
-        self.strings[bindparam] = ":" + key
+        self.strings[bindparam] = self.bindparam_string(key)
 
+    def bindparam_string(self, name):
+        return ":" + name
+        
     def visit_alias(self, alias):
         self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name
         self.strings[alias] = self.get_str(alias.selectable)
@@ -221,7 +224,7 @@ class ANSICompiler(sql.Compiled):
             self.binds[b.shortname] = b
             
         text = ("INSERT INTO " + insert_stmt.table.name + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
-         " VALUES (" + string.join([":" + c[1].key for c in colparams], ', ') + ")")
+         " VALUES (" + string.join([self.bindparam_string(c[1].key) for c in colparams], ', ') + ")")
          
         self.strings[insert_stmt] = text
 
@@ -231,7 +234,7 @@ class ANSICompiler(sql.Compiled):
             if isinstance(p, BindParamClause):
                 self.binds[p.key] = p
                 self.binds[p.shortname] = p
-                return ":" + p.key
+                return self.bindparam_string(p.key)
             else:
                 p.accept_visitor(self)
                 if isinstance(p, ClauseElement):
diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py
index 375f3c177863c4dbead4081bafb05c3fdd056c58..a4361af6771a692722548a5ded031b8608245d08 100644
--- a/lib/sqlalchemy/databases/postgres.py
+++ b/lib/sqlalchemy/databases/postgres.py
 # along with this library; if not, write to the Free Software
 # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
 
-import sys, StringIO, string
+import sys, StringIO, string, types, re
 
 import sqlalchemy.sql as sql
+import sqlalchemy.engine as engine
 import sqlalchemy.schema as schema
 import sqlalchemy.ansisql as ansisql
+import sqlalchemy.types as sqltypes
 from sqlalchemy.ansisql import *
 
+class PGNumeric(sqltypes.Numeric):
+    def get_col_spec(self):
+        return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+class PGInteger(sqltypes.Integer):
+    def get_col_spec(self):
+        return "INTEGER"
+class PGDateTime(sqltypes.DateTime):
+    def get_col_spec(self):
+        return "TIMESTAMP"
+class PGText(sqltypes.TEXT):
+    def get_col_spec(self):
+        return "TEXT"
+class PGString(sqltypes.String):
+    def get_col_spec(self):
+        return "VARCHAR(%(length)s)" % {'length' : self.length}
+class PGChar(sqltypes.CHAR):
+    def get_col_spec(self):
+        return "CHAR(%(length)s)" % {'length' : self.length}
+class PGBinary(sqltypes.Binary):
+    def get_col_spec(self):
+        return "BLOB"
+class PGBoolean(sqltypes.Boolean):
+    def get_col_spec(self):
+        return "BOOLEAN"
+        
 colspecs = {
-    schema.INT : "INTEGER",
-    schema.CHAR : "CHAR(%(length)s)",
-    schema.VARCHAR : "VARCHAR(%(length)s)",
-    schema.TEXT : "TEXT",
-    schema.FLOAT : "NUMERIC(%(precision)s, %(length)s)",
-    schema.DECIMAL : "NUMERIC(%(precision)s, %(length)s)",
-    schema.TIMESTAMP : "TIMESTAMP",
-    schema.DATETIME : "TIMESTAMP",
-    schema.CLOB : "TEXT",
-    schema.BLOB : "BLOB",
-    schema.BOOLEAN : "BOOLEAN",
+    sqltypes.Integer : PGInteger,
+    sqltypes.Numeric : PGNumeric,
+    sqltypes.DateTime : PGDateTime,
+    sqltypes.String : PGString,
+    sqltypes.Binary : PGBinary,
+    sqltypes.Boolean : PGBoolean,
+    sqltypes.TEXT : PGText,
+    sqltypes.CHAR: PGChar,
 }
 
-
-def engine(**params):
-    return PGSQLEngine(**params)
+def engine(opts, **params):
+    return PGSQLEngine(opts, **params)
 
 class PGSQLEngine(ansisql.ANSISQLEngine):
-    def __init__(self, **params):
+    def __init__(self, opts, module = None, **params):
+        if module is None:
+            self.module = __import__('psycopg2')
+        else:
+            self.module = module
+        self.opts = opts or {}
         ansisql.ANSISQLEngine.__init__(self, **params)
 
     def connect_args(self):
-        return [[], {}]
+        return [[], self.opts]
 
-    def compile(self, statement, bindparams):
-        compiler = PGCompiler(self, statement, bindparams)
-        statement.accept_visitor(compiler)
-        return compiler
+
+    def type_descriptor(self, typeobj):
+        return sqltypes.adapt_type(typeobj, colspecs)
 
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
+    def compiler(self, statement, bindparams):
+        return PGCompiler(self, statement, bindparams)
+
+    def schemagenerator(self, proxy, **params):
+        return PGSchemaGenerator(proxy, **params)
+
+    def reflecttable(self, table):
+        raise "not implemented"
+        
+    def last_inserted_ids(self):
+        return self.context.last_inserted_ids
+
     def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+        if True: return
         if compiled is None: return
         if getattr(compiled, "isinsert", False):
             last_inserted_ids = []
@@ -70,25 +110,33 @@ class PGSQLEngine(ansisql.ANSISQLEngine):
                     last_inserted_ids.append(newid)
             self.context.last_inserted_ids = last_inserted_ids
 
-    def dbapi(self):
-        return None
-#        return psycopg
+    def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs):
+        if compiled is None: return
+        if getattr(compiled, "isinsert", False):
+            self.context.last_inserted_ids = [cursor.lastrowid]
 
-    def columnimpl(self, column):
-        return PGColumnImpl(column)
+    def dbapi(self):
+        return self.module
 
     def reflecttable(self, table):
         raise NotImplementedError()
 
 class PGCompiler(ansisql.ANSICompiler):
-    pass
-
-class PGColumnImpl(sql.ColumnSelectable):
-    def get_specification(self):
-        coltype = self.column.type
-        if isinstance(coltype, types.ClassType):
-            key = coltype
+    def bindparam_string(self, name):
+        return "%(" + name + ")s"
+
+class PGSchemaGenerator(ansisql.ANSISchemaGenerator):
+    def get_column_specification(self, column):
+        colspec = column.name
+        if column.primary_key and isinstance(column.type, types.Integer):
+            colspec += " SERIAL"
         else:
-            key = coltype.__class__
-
-        return self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', None), 'length' : getattr(coltype, 'length', None)}
+            colspec += " " + column.column.type.get_col_spec()
+
+        if not column.nullable:
+            colspec += " NOT NULL"
+        if column.primary_key:
+            colspec += " PRIMARY KEY"
+        if column.foreign_key:
+            colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) 
+        return colspec
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
index 62443771f31a64b0dcc055eeb77c7a75855efaa2..d613728cbec099409849f496654e79826213b557 100644
--- a/lib/sqlalchemy/databases/sqlite.py
+++ b/lib/sqlalchemy/databases/sqlite.py
@@ -141,12 +141,12 @@ class SQLiteCompiler(ansisql.ANSICompiler):
 
 class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator):
     def get_column_specification(self, column):
-        colspec = column.name + " " + column.column.type.get_col_spec()
-        if not column.column.nullable:
+        colspec = column.name + " " + column.type.get_col_spec()
+        if not column.nullable:
             colspec += " NOT NULL"
-        if column.column.primary_key:
+        if column.primary_key:
             colspec += " PRIMARY KEY"
-        if column.column.foreign_key:
-            colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) 
+        if column.foreign_key:
+            colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) 
         return colspec
 
diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py
index 842d67f1970db29c95b3ddbf8d2cfd6b86e451e6..f64074b8bbfc115fcec9d440e0a155174ba64700 100644
--- a/lib/sqlalchemy/mapper.py
+++ b/lib/sqlalchemy/mapper.py
@@ -397,9 +397,17 @@ class Mapper(object):
                 
 #                print "SAVE_OBJ we are " + hash_key(self) + " obj: " +  obj.__class__.__name__ + repr(id(obj))
                 params = {}
+
                 for col in table.columns:
-                    if col.primary_key and hasattr(obj, "_instance_key"):
-                        params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col)
+                    if col.primary_key:
+                        if hasattr(obj, "_instance_key"):
+                            params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col)
+                        else:
+                            # its an INSERT - if its NULL, leave it out as pgsql doesnt 
+                            # like it for an autoincrement
+                            value = self._getattrbycolumn(obj, col)
+                            if value is not None:
+                                params[col.key] = value
                     else:
                         params[col.key] = self._getattrbycolumn(obj, col)
 
@@ -730,7 +738,7 @@ class PropertyLoader(MapperProperty):
 
     def _compile_synchronizers(self):
         def compile(binary):
-            if binary.operator != '=':
+            if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column):
                 return
 
             if binary.left.table == binary.right.table:
@@ -998,7 +1006,11 @@ class EagerLoader(PropertyLoader):
             towrap = self.parent.table
 
         if self.secondaryjoin is not None:
-            statement._outerjoin = sql.outerjoin(towrap, self.secondary, self.secondaryjoin).outerjoin(self.target, self.primaryjoin)
+            print self.secondary.name
+            print str(self.secondaryjoin)
+            print self.target.name
+            print str(self.primaryjoin)
+            statement._outerjoin = sql.outerjoin(towrap, self.secondary, self.primaryjoin).outerjoin(self.target, self.secondaryjoin)
         else:
             statement._outerjoin = towrap.outerjoin(self.target, self.primaryjoin)
 
diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py
index 8952f038be446e6d7898a4eb57aef7309bff774a..9fad004b13d8461b3bf3e889b904994fbd294f9b 100644
--- a/lib/sqlalchemy/schema.py
+++ b/lib/sqlalchemy/schema.py
@@ -80,7 +80,7 @@ class Table(SchemaItem):
         self.name = name
         self.columns = OrderedProperties()
         self.c = self.columns
-        self.foreign_keys = OrderedProperties()
+        self.foreign_keys = []
         self.primary_keys = []
         self.engine = engine
         self._impl = self.engine.tableimpl(self)
@@ -204,7 +204,6 @@ class ForeignKey(SchemaItem):
             else:
                 self._column = self._colspec
 
-            self.parent.table.foreign_keys[self._column.key] = self
         return self._column
             
     column = property(lambda s: s._init_column())
@@ -212,6 +211,7 @@ class ForeignKey(SchemaItem):
     def _set_parent(self, column):
         self.parent = column
         self.parent.foreign_key = self
+        self.parent.table.foreign_keys.append(self)
         
 class Sequence(SchemaItem):
     """represents a sequence, which applies to Oracle and Postgres databases."""
diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py
index 03cba39563e0c555e70d4bd884ac8023e7678a4d..6322b95227419c440985b15e986aca198472dbea 100644
--- a/lib/sqlalchemy/sql.py
+++ b/lib/sqlalchemy/sql.py
@@ -463,6 +463,11 @@ class Join(Selectable):
 
     primary_keys = property (lambda self: [c for c in self.left.columns if c.primary_key] + [c for c in self.right.columns if c.primary_key])
 
+
+    def group_parenthesized(self):
+        """indicates if this Selectable requires parenthesis when grouped into a compound statement"""
+        return False
+
     def hash_key(self):
         return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter))
 
diff --git a/test/mapper.py b/test/mapper.py
index 36bbbb2542d400201ba9cc5faacc1540a085fc02..978a117158be5718b5a24283508e19d45dbb22f2 100644
--- a/test/mapper.py
+++ b/test/mapper.py
@@ -215,7 +215,7 @@ class EagerTest(AssertMixin):
         c = s.compile()
         self.echo("\n" + str(c) + repr(c.get_params()))
         
-        l = m.instances(s.execute(emailad = 'jack@bean.com'), users.engine)
+        l = m.instances(s.execute(emailad = 'jack@bean.com'))
         self.echo(repr(l))
         
     def testmulti(self):
@@ -308,19 +308,19 @@ class EagerTest(AssertMixin):
         m = mapper(Item, items, properties = dict(
                 keywords = relation(Keyword, keywords, itemkeywords, lazy = False),
             ))
-        l = m.select()
+        l = m.select(order_by=[items.c.item_id, keywords.c.keyword_id])
         self.assert_result(l, Item, 
             {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
-            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 7, 'name':'square'}, {'keyword_id' : 5, 'name':'small'}])},
-            {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 6,'name':'round'}, {'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}])},
+            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 5, 'name':'small'}, {'keyword_id' : 7, 'name':'square'}])},
+            {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}, {'keyword_id' : 6,'name':'round'}])},
             {'item_id' : 4, 'keywords' : (Keyword, [])},
             {'item_id' : 5, 'keywords' : (Keyword, [])}
         )
         
-        l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id))
+        l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id), order_by=[items.c.item_id, keywords.c.keyword_id])
         self.assert_result(l, Item, 
             {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])},
-            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 7}, {'keyword_id' : 5}])},
+            {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])},
         )
     
     def testoneandmany(self):
diff --git a/test/objectstore.py b/test/objectstore.py
index 5b5c8591450506a40bda8d5fad8478c399070e79..71ba461266bc7254eca1d86147f81462f134bdf1 100644
--- a/test/objectstore.py
+++ b/test/objectstore.py
@@ -6,9 +6,10 @@ import sqlalchemy.objectstore as objectstore
 import testbase
 
 echo = testbase.echo
-testbase.echo = False
+#testbase.echo = False
 from tables import *
 
+itemkeywords.delete().execute()
 keywords.delete().execute()
 keywords.insert().execute(
     dict(keyword_id=1, name='blue'),
@@ -47,11 +48,11 @@ class SaveTest(AssertMixin):
         db.echo = False
         objectstore.clear()
         clear_mappers()
-        orders.delete().execute()
+        itemkeywords.delete().execute()
         orderitems.delete().execute()
-        users.delete().execute()
+        orders.delete().execute()
         addresses.delete().execute()
-        itemkeywords.delete().execute()
+        users.delete().execute()
         
         db.echo = e
         
diff --git a/test/tables.py b/test/tables.py
index 6dc1a36cf785a301f30ae95f472d84231fd5a4ee..79f32cf554ae6f1b016f70c7a64e5a056160ed4c 100644
--- a/test/tables.py
+++ b/test/tables.py
@@ -10,31 +10,33 @@ __ALL__ = ['db', 'users', 'addresses', 'orders', 'orderitems', 'keywords', 'item
 
 ECHO = testbase.echo
 DATA = True
-
-DBTYPE = 'sqlite_memory'
+CREATE = False
+#CREATE = True
+#DBTYPE = 'sqlite_memory'
+DBTYPE = 'postgres'
 #DBTYPE = 'sqlite_file'
 
 if DBTYPE == 'sqlite_memory':
     db = sqlalchemy.engine.create_engine('sqlite', ':memory:', {}, echo = testbase.echo)
 elif DBTYPE == 'sqlite_file':
     import sqlalchemy.databases.sqlite as sqllite
-    if os.access('querytest.db', os.F_OK):
-        os.remove('querytest.db')
+#    if os.access('querytest.db', os.F_OK):
#       os.remove('querytest.db')
     db = sqlalchemy.engine.create_engine('sqlite', 'querytest.db', {}, echo = testbase.echo)
 elif DBTYPE == 'postgres':
-    pass
+    db = sqlalchemy.engine.create_engine('postgres', {'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo=testbase.echo)
 
 db = testbase.EngineAssert(db)
 
 users = Table('users', db,
     Column('user_id', Integer, primary_key = True),
-    Column('user_name', String(20)),
+    Column('user_name', String(40)),
 )
 
 addresses = Table('email_addresses', db,
     Column('address_id', Integer, primary_key = True),
     Column('user_id', Integer, ForeignKey(users.c.user_id)),
-    Column('email_address', String(20)),
+    Column('email_address', String(40)),
 )
 
 orders = Table('orders', db,
@@ -60,25 +62,32 @@ itemkeywords = Table('itemkeywords', db,
     Column('keyword_id', INT, ForeignKey("keywords"))
 )
 
-users.create()
+if CREATE:
+    users.create()
+    addresses.create()
+    orders.create()
+    orderitems.create()
+    keywords.create()
+    itemkeywords.create()
+
 
 if DATA:
+    itemkeywords.delete().execute()
+    keywords.delete().execute()
+    orderitems.delete().execute()
+    orders.delete().execute()
+    addresses.delete().execute()
+    users.delete().execute()
     users.insert().execute(
         dict(user_id = 7, user_name = 'jack'),
         dict(user_id = 8, user_name = 'ed'),
         dict(user_id = 9, user_name = 'fred')
     )
-
-addresses.create()
-if DATA:
     addresses.insert().execute(
         dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"),
         dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"),
         dict(address_id = 3, user_id = 8, email_address = "ed@lala.com")
     )
-
-orders.create()
-if DATA:
     orders.insert().execute(
         dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0),
         dict(order_id = 2, user_id = 9, description = 'order 2', isopen=0),
@@ -86,9 +95,6 @@ if DATA:
         dict(order_id = 4, user_id = 9, description = 'order 4', isopen=1),
         dict(order_id = 5, user_id = 7, description = 'order 5', isopen=0)
     )
-
-orderitems.create()
-if DATA:
     orderitems.insert().execute(
         dict(item_id=1, order_id=2, item_name='item 1'),
         dict(item_id=3, order_id=3, item_name='item 3'),
@@ -96,9 +102,6 @@ if DATA:
         dict(item_id=5, order_id=3, item_name='item 5'),
         dict(item_id=4, order_id=3, item_name='item 4')
     )
-
-keywords.create()
-if DATA:
     keywords.insert().execute(
         dict(keyword_id=1, name='blue'),
         dict(keyword_id=2, name='red'),
@@ -108,9 +111,6 @@ if DATA:
         dict(keyword_id=6, name='round'),
         dict(keyword_id=7, name='square')
     )
-
-itemkeywords.create()
-if DATA:
     itemkeywords.insert().execute(
         dict(keyword_id=2, item_id=1),
         dict(keyword_id=2, item_id=2),
@@ -122,6 +122,7 @@ if DATA:
         dict(keyword_id=5, item_id=2),
         dict(keyword_id=4, item_id=3)
     )
+
 db.connection().commit()
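
For readers skimming the diff, the core mechanism of this commit is the new bindparam_string() hook added to ANSICompiler in ansisql.py and overridden by PGCompiler in postgres.py, so that bind parameters render as pyformat "%(name)s" placeholders (as expected by psycopg-style DBAPIs) instead of the ANSI ":name" style. Below is a minimal standalone sketch of that pattern; the class and method names other than bindparam_string are illustrative simplifications, not the commit's actual code.

# Sketch only: simplified stand-ins for the compiler classes touched above.
class AnsiCompilerSketch(object):
    def bindparam_string(self, name):
        # ANSI default: named parameter style, e.g. ":user_id"
        return ":" + name

    def insert_statement_sketch(self, table, columns):
        # mirrors how visit_insert builds its VALUES clause from the hook
        placeholders = ", ".join(self.bindparam_string(c) for c in columns)
        return "INSERT INTO %s (%s) VALUES (%s)" % (
            table, ", ".join(columns), placeholders)

class PgCompilerSketch(AnsiCompilerSketch):
    def bindparam_string(self, name):
        # pyformat style, e.g. "%(user_id)s", consumed from a parameter dict
        return "%(" + name + ")s"

if __name__ == "__main__":
    cols = ["user_id", "user_name"]
    print(AnsiCompilerSketch().insert_statement_sketch("users", cols))
    # INSERT INTO users (user_id, user_name) VALUES (:user_id, :user_name)
    print(PgCompilerSketch().insert_statement_sketch("users", cols))
    # INSERT INTO users (user_id, user_name) VALUES (%(user_id)s, %(user_name)s)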