From 7f60baef89be3a84db09a8208a9b625af6b19876 Mon Sep 17 00:00:00 2001 From: Mike Bayer Date: Fri, 21 Oct 2005 03:43:22 +0000 Subject: [PATCH] postgres kickin my ass w00p --- lib/sqlalchemy/ansisql.py | 9 +- lib/sqlalchemy/databases/postgres.py | 120 +++++++++++++++++++-------- lib/sqlalchemy/databases/sqlite.py | 10 +-- lib/sqlalchemy/mapper.py | 20 ++++- lib/sqlalchemy/schema.py | 4 +- lib/sqlalchemy/sql.py | 5 ++ test/mapper.py | 12 +-- test/objectstore.py | 9 +- test/tables.py | 47 ++++++----- 9 files changed, 153 insertions(+), 83 deletions(-) diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py index 95aa47cde4..4f90d485c5 100644 --- a/lib/sqlalchemy/ansisql.py +++ b/lib/sqlalchemy/ansisql.py @@ -138,8 +138,11 @@ class ANSICompiler(sql.Compiled): while self.binds.setdefault(key, bindparam) is not bindparam: key = "%s_%d" % (bindparam.key, count) count += 1 - self.strings[bindparam] = ":" + key + self.strings[bindparam] = self.bindparam_string(key) + def bindparam_string(self, name): + return ":" + name + def visit_alias(self, alias): self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name self.strings[alias] = self.get_str(alias.selectable) @@ -221,7 +224,7 @@ class ANSICompiler(sql.Compiled): self.binds[b.shortname] = b text = ("INSERT INTO " + insert_stmt.table.name + " (" + string.join([c[0].name for c in colparams], ', ') + ")" + - " VALUES (" + string.join([":" + c[1].key for c in colparams], ', ') + ")") + " VALUES (" + string.join([self.bindparam_string(c[1].key) for c in colparams], ', ') + ")") self.strings[insert_stmt] = text @@ -231,7 +234,7 @@ class ANSICompiler(sql.Compiled): if isinstance(p, BindParamClause): self.binds[p.key] = p self.binds[p.shortname] = p - return ":" + p.key + return self.bindparam_string(p.key) else: p.accept_visitor(self) if isinstance(p, ClauseElement): diff --git a/lib/sqlalchemy/databases/postgres.py b/lib/sqlalchemy/databases/postgres.py index 375f3c1778..a4361af677 100644 --- a/lib/sqlalchemy/databases/postgres.py +++ b/lib/sqlalchemy/databases/postgres.py @@ -15,47 +15,87 @@ # along with this library; if not, write to the Free Software # Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. -import sys, StringIO, string +import sys, StringIO, string, types, re import sqlalchemy.sql as sql +import sqlalchemy.engine as engine import sqlalchemy.schema as schema import sqlalchemy.ansisql as ansisql +import sqlalchemy.types as sqltypes from sqlalchemy.ansisql import * +class PGNumeric(sqltypes.Numeric): + def get_col_spec(self): + return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length} +class PGInteger(sqltypes.Integer): + def get_col_spec(self): + return "INTEGER" +class PGDateTime(sqltypes.DateTime): + def get_col_spec(self): + return "TIMESTAMP" +class PGText(sqltypes.TEXT): + def get_col_spec(self): + return "TEXT" +class PGString(sqltypes.String): + def get_col_spec(self): + return "VARCHAR(%(length)s)" % {'length' : self.length} +class PGChar(sqltypes.CHAR): + def get_col_spec(self): + return "CHAR(%(length)s)" % {'length' : self.length} +class PGBinary(sqltypes.Binary): + def get_col_spec(self): + return "BLOB" +class PGBoolean(sqltypes.Boolean): + def get_col_spec(self): + return "BOOLEAN" + colspecs = { - schema.INT : "INTEGER", - schema.CHAR : "CHAR(%(length)s)", - schema.VARCHAR : "VARCHAR(%(length)s)", - schema.TEXT : "TEXT", - schema.FLOAT : "NUMERIC(%(precision)s, %(length)s)", - schema.DECIMAL : "NUMERIC(%(precision)s, %(length)s)", - schema.TIMESTAMP : "TIMESTAMP", - schema.DATETIME : "TIMESTAMP", - schema.CLOB : "TEXT", - schema.BLOB : "BLOB", - schema.BOOLEAN : "BOOLEAN", + sqltypes.Integer : PGInteger, + sqltypes.Numeric : PGNumeric, + sqltypes.DateTime : PGDateTime, + sqltypes.String : PGString, + sqltypes.Binary : PGBinary, + sqltypes.Boolean : PGBoolean, + sqltypes.TEXT : PGText, + sqltypes.CHAR: PGChar, } - -def engine(**params): - return PGSQLEngine(**params) +def engine(opts, **params): + return PGSQLEngine(opts, **params) class PGSQLEngine(ansisql.ANSISQLEngine): - def __init__(self, **params): + def __init__(self, opts, module = None, **params): + if module is None: + self.module = __import__('psycopg2') + else: + self.module = module + self.opts = opts or {} ansisql.ANSISQLEngine.__init__(self, **params) def connect_args(self): - return [[], {}] + return [[], self.opts] - def compile(self, statement, bindparams): - compiler = PGCompiler(self, statement, bindparams) - statement.accept_visitor(compiler) - return compiler + + def type_descriptor(self, typeobj): + return sqltypes.adapt_type(typeobj, colspecs) def last_inserted_ids(self): return self.context.last_inserted_ids + def compiler(self, statement, bindparams): + return PGCompiler(self, statement, bindparams) + + def schemagenerator(self, proxy, **params): + return PGSchemaGenerator(proxy, **params) + + def reflecttable(self, table): + raise "not implemented" + + def last_inserted_ids(self): + return self.context.last_inserted_ids + def pre_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): + if True: return if compiled is None: return if getattr(compiled, "isinsert", False): last_inserted_ids = [] @@ -70,25 +110,33 @@ class PGSQLEngine(ansisql.ANSISQLEngine): last_inserted_ids.append(newid) self.context.last_inserted_ids = last_inserted_ids - def dbapi(self): - return None -# return psycopg + def post_exec(self, connection, cursor, statement, parameters, echo = None, compiled = None, **kwargs): + if compiled is None: return + if getattr(compiled, "isinsert", False): + self.context.last_inserted_ids = [cursor.lastrowid] - def columnimpl(self, column): - return PGColumnImpl(column) + def dbapi(self): + return self.module def reflecttable(self, table): raise NotImplementedError() class PGCompiler(ansisql.ANSICompiler): - pass - -class PGColumnImpl(sql.ColumnSelectable): - def get_specification(self): - coltype = self.column.type - if isinstance(coltype, types.ClassType): - key = coltype + def bindparam_string(self, name): + return "%(" + name + ")s" + +class PGSchemaGenerator(ansisql.ANSISchemaGenerator): + def get_column_specification(self, column): + colspec = column.name + if column.primary_key and isinstance(column.type, types.Integer): + colspec += " SERIAL" else: - key = coltype.__class__ - - return self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', None), 'length' : getattr(coltype, 'length', None)} + colspec += " " + column.column.type.get_col_spec() + + if not column.nullable: + colspec += " NOT NULL" + if column.primary_key: + colspec += " PRIMARY KEY" + if column.foreign_key: + colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) + return colspec diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py index 62443771f3..d613728cbe 100644 --- a/lib/sqlalchemy/databases/sqlite.py +++ b/lib/sqlalchemy/databases/sqlite.py @@ -141,12 +141,12 @@ class SQLiteCompiler(ansisql.ANSICompiler): class SQLiteSchemaGenerator(ansisql.ANSISchemaGenerator): def get_column_specification(self, column): - colspec = column.name + " " + column.column.type.get_col_spec() - if not column.column.nullable: + colspec = column.name + " " + column.type.get_col_spec() + if not column.nullable: colspec += " NOT NULL" - if column.column.primary_key: + if column.primary_key: colspec += " PRIMARY KEY" - if column.column.foreign_key: - colspec += " REFERENCES %s(%s)" % (column.column.foreign_key.column.table.name, column.column.foreign_key.column.name) + if column.foreign_key: + colspec += " REFERENCES %s(%s)" % (column.foreign_key.column.table.name, column.foreign_key.column.name) return colspec diff --git a/lib/sqlalchemy/mapper.py b/lib/sqlalchemy/mapper.py index 842d67f197..f64074b8bb 100644 --- a/lib/sqlalchemy/mapper.py +++ b/lib/sqlalchemy/mapper.py @@ -397,9 +397,17 @@ class Mapper(object): # print "SAVE_OBJ we are " + hash_key(self) + " obj: " + obj.__class__.__name__ + repr(id(obj)) params = {} + for col in table.columns: - if col.primary_key and hasattr(obj, "_instance_key"): - params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col) + if col.primary_key: + if hasattr(obj, "_instance_key"): + params[col.table.name + "_" + col.key] = self._getattrbycolumn(obj, col) + else: + # its an INSERT - if its NULL, leave it out as pgsql doesnt + # like it for an autoincrement + value = self._getattrbycolumn(obj, col) + if value is not None: + params[col.key] = value else: params[col.key] = self._getattrbycolumn(obj, col) @@ -730,7 +738,7 @@ class PropertyLoader(MapperProperty): def _compile_synchronizers(self): def compile(binary): - if binary.operator != '=': + if binary.operator != '=' or not isinstance(binary.left, schema.Column) or not isinstance(binary.right, schema.Column): return if binary.left.table == binary.right.table: @@ -998,7 +1006,11 @@ class EagerLoader(PropertyLoader): towrap = self.parent.table if self.secondaryjoin is not None: - statement._outerjoin = sql.outerjoin(towrap, self.secondary, self.secondaryjoin).outerjoin(self.target, self.primaryjoin) + print self.secondary.name + print str(self.secondaryjoin) + print self.target.name + print str(self.primaryjoin) + statement._outerjoin = sql.outerjoin(towrap, self.secondary, self.primaryjoin).outerjoin(self.target, self.secondaryjoin) else: statement._outerjoin = towrap.outerjoin(self.target, self.primaryjoin) diff --git a/lib/sqlalchemy/schema.py b/lib/sqlalchemy/schema.py index 8952f038be..9fad004b13 100644 --- a/lib/sqlalchemy/schema.py +++ b/lib/sqlalchemy/schema.py @@ -80,7 +80,7 @@ class Table(SchemaItem): self.name = name self.columns = OrderedProperties() self.c = self.columns - self.foreign_keys = OrderedProperties() + self.foreign_keys = [] self.primary_keys = [] self.engine = engine self._impl = self.engine.tableimpl(self) @@ -204,7 +204,6 @@ class ForeignKey(SchemaItem): else: self._column = self._colspec - self.parent.table.foreign_keys[self._column.key] = self return self._column column = property(lambda s: s._init_column()) @@ -212,6 +211,7 @@ class ForeignKey(SchemaItem): def _set_parent(self, column): self.parent = column self.parent.foreign_key = self + self.parent.table.foreign_keys.append(self) class Sequence(SchemaItem): """represents a sequence, which applies to Oracle and Postgres databases.""" diff --git a/lib/sqlalchemy/sql.py b/lib/sqlalchemy/sql.py index 03cba39563..6322b95227 100644 --- a/lib/sqlalchemy/sql.py +++ b/lib/sqlalchemy/sql.py @@ -463,6 +463,11 @@ class Join(Selectable): primary_keys = property (lambda self: [c for c in self.left.columns if c.primary_key] + [c for c in self.right.columns if c.primary_key]) + + def group_parenthesized(self): + """indicates if this Selectable requires parenthesis when grouped into a compound statement""" + return False + def hash_key(self): return "Join(%s, %s, %s, %s)" % (repr(self.left.hash_key()), repr(self.right.hash_key()), repr(self.onclause.hash_key()), repr(self.isouter)) diff --git a/test/mapper.py b/test/mapper.py index 36bbbb2542..978a117158 100644 --- a/test/mapper.py +++ b/test/mapper.py @@ -215,7 +215,7 @@ class EagerTest(AssertMixin): c = s.compile() self.echo("\n" + str(c) + repr(c.get_params())) - l = m.instances(s.execute(emailad = 'jack@bean.com'), users.engine) + l = m.instances(s.execute(emailad = 'jack@bean.com')) self.echo(repr(l)) def testmulti(self): @@ -308,19 +308,19 @@ class EagerTest(AssertMixin): m = mapper(Item, items, properties = dict( keywords = relation(Keyword, keywords, itemkeywords, lazy = False), )) - l = m.select() + l = m.select(order_by=[items.c.item_id, keywords.c.keyword_id]) self.assert_result(l, Item, {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, - {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 7, 'name':'square'}, {'keyword_id' : 5, 'name':'small'}])}, - {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 6,'name':'round'}, {'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}])}, + {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2, 'name':'red'}, {'keyword_id' : 5, 'name':'small'}, {'keyword_id' : 7, 'name':'square'}])}, + {'item_id' : 3, 'keywords' : (Keyword, [{'keyword_id' : 3,'name':'green'}, {'keyword_id' : 4,'name':'big'}, {'keyword_id' : 6,'name':'round'}])}, {'item_id' : 4, 'keywords' : (Keyword, [])}, {'item_id' : 5, 'keywords' : (Keyword, [])} ) - l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id)) + l = m.select(and_(keywords.c.name == 'red', keywords.c.keyword_id == itemkeywords.c.keyword_id, items.c.item_id==itemkeywords.c.item_id), order_by=[items.c.item_id, keywords.c.keyword_id]) self.assert_result(l, Item, {'item_id' : 1, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 4}, {'keyword_id' : 6}])}, - {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 7}, {'keyword_id' : 5}])}, + {'item_id' : 2, 'keywords' : (Keyword, [{'keyword_id' : 2}, {'keyword_id' : 5}, {'keyword_id' : 7}])}, ) def testoneandmany(self): diff --git a/test/objectstore.py b/test/objectstore.py index 5b5c859145..71ba461266 100644 --- a/test/objectstore.py +++ b/test/objectstore.py @@ -6,9 +6,10 @@ import sqlalchemy.objectstore as objectstore import testbase echo = testbase.echo -testbase.echo = False +#testbase.echo = False from tables import * +itemkeywords.delete().execute() keywords.delete().execute() keywords.insert().execute( dict(keyword_id=1, name='blue'), @@ -47,11 +48,11 @@ class SaveTest(AssertMixin): db.echo = False objectstore.clear() clear_mappers() - orders.delete().execute() + itemkeywords.delete().execute() orderitems.delete().execute() - users.delete().execute() + orders.delete().execute() addresses.delete().execute() - itemkeywords.delete().execute() + users.delete().execute() db.echo = e diff --git a/test/tables.py b/test/tables.py index 6dc1a36cf7..79f32cf554 100644 --- a/test/tables.py +++ b/test/tables.py @@ -10,31 +10,33 @@ __ALL__ = ['db', 'users', 'addresses', 'orders', 'orderitems', 'keywords', 'item ECHO = testbase.echo DATA = True - -DBTYPE = 'sqlite_memory' +CREATE = False +#CREATE = True +#DBTYPE = 'sqlite_memory' +DBTYPE = 'postgres' #DBTYPE = 'sqlite_file' if DBTYPE == 'sqlite_memory': db = sqlalchemy.engine.create_engine('sqlite', ':memory:', {}, echo = testbase.echo) elif DBTYPE == 'sqlite_file': import sqlalchemy.databases.sqlite as sqllite - if os.access('querytest.db', os.F_OK): - os.remove('querytest.db') +# if os.access('querytest.db', os.F_OK): + # os.remove('querytest.db') db = sqlalchemy.engine.create_engine('sqlite', 'querytest.db', {}, echo = testbase.echo) elif DBTYPE == 'postgres': - pass + db = sqlalchemy.engine.create_engine('postgres', {'database':'test', 'host':'127.0.0.1', 'user':'scott', 'password':'tiger'}, echo=testbase.echo) db = testbase.EngineAssert(db) users = Table('users', db, Column('user_id', Integer, primary_key = True), - Column('user_name', String(20)), + Column('user_name', String(40)), ) addresses = Table('email_addresses', db, Column('address_id', Integer, primary_key = True), Column('user_id', Integer, ForeignKey(users.c.user_id)), - Column('email_address', String(20)), + Column('email_address', String(40)), ) orders = Table('orders', db, @@ -60,25 +62,32 @@ itemkeywords = Table('itemkeywords', db, Column('keyword_id', INT, ForeignKey("keywords")) ) -users.create() +if CREATE: + users.create() + addresses.create() + orders.create() + orderitems.create() + keywords.create() + itemkeywords.create() + if DATA: + itemkeywords.delete().execute() + keywords.delete().execute() + orderitems.delete().execute() + orders.delete().execute() + addresses.delete().execute() + users.delete().execute() users.insert().execute( dict(user_id = 7, user_name = 'jack'), dict(user_id = 8, user_name = 'ed'), dict(user_id = 9, user_name = 'fred') ) - -addresses.create() -if DATA: addresses.insert().execute( dict(address_id = 1, user_id = 7, email_address = "jack@bean.com"), dict(address_id = 2, user_id = 8, email_address = "ed@wood.com"), dict(address_id = 3, user_id = 8, email_address = "ed@lala.com") ) - -orders.create() -if DATA: orders.insert().execute( dict(order_id = 1, user_id = 7, description = 'order 1', isopen=0), dict(order_id = 2, user_id = 9, description = 'order 2', isopen=0), @@ -86,9 +95,6 @@ if DATA: dict(order_id = 4, user_id = 9, description = 'order 4', isopen=1), dict(order_id = 5, user_id = 7, description = 'order 5', isopen=0) ) - -orderitems.create() -if DATA: orderitems.insert().execute( dict(item_id=1, order_id=2, item_name='item 1'), dict(item_id=3, order_id=3, item_name='item 3'), @@ -96,9 +102,6 @@ if DATA: dict(item_id=5, order_id=3, item_name='item 5'), dict(item_id=4, order_id=3, item_name='item 4') ) - -keywords.create() -if DATA: keywords.insert().execute( dict(keyword_id=1, name='blue'), dict(keyword_id=2, name='red'), @@ -108,9 +111,6 @@ if DATA: dict(keyword_id=6, name='round'), dict(keyword_id=7, name='square') ) - -itemkeywords.create() -if DATA: itemkeywords.insert().execute( dict(keyword_id=2, item_id=1), dict(keyword_id=2, item_id=2), @@ -122,6 +122,7 @@ if DATA: dict(keyword_id=5, item_id=2), dict(keyword_id=4, item_id=3) ) + db.connection().commit() -- 2.47.2