]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
Initial revision
authorMike Bayer <mike_mp@zzzcomputing.com>
Fri, 1 Jul 2005 02:43:15 +0000 (02:43 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Fri, 1 Jul 2005 02:43:15 +0000 (02:43 +0000)
lib/sqlalchemy/ansisql.py [new file with mode: 0644]
lib/sqlalchemy/databases/sqlite.py [new file with mode: 0644]
lib/sqlalchemy/engine.py [new file with mode: 0644]

diff --git a/lib/sqlalchemy/ansisql.py b/lib/sqlalchemy/ansisql.py
new file mode 100644 (file)
index 0000000..bd63aa9
--- /dev/null
@@ -0,0 +1,244 @@
+"""defines ANSI SQL operations."""
+
+import sqlalchemy.schema as schema
+
+from sqlalchemy.schema import *
+import sqlalchemy.sql as sql
+import sqlalchemy.engine
+from sqlalchemy.sql import *
+from sqlalchemy.util import *
+import string
+        
+def engine(**params):
+    return ANSISQLEngine(**params)
+    
+class ANSISQLEngine(sqlalchemy.engine.SQLEngine):
+
+    def tableimpl(self, table):
+        return ANSISQLTableImpl(table)
+
+    def schemagenerator(self, proxy, **params):
+        return ANSISchemaGenerator(proxy, **params)
+    
+    def schemadropper(self, proxy, **params):
+        return ANSISchemaDropper(proxy, **params)
+
+    def connect_args(self):
+        return ([],{})
+        
+    def dbapi(self):
+        return object()
+        
+    def compile(self, statement, bindparams):
+        compiler = ANSICompiler(statement, bindparams)
+        
+        statement.accept_visitor(compiler)
+        return compiler
+
+class ANSICompiler(sql.Compiled):
+    def __init__(self, parent, bindparams):
+        self.binds = {}
+        self.bindparams = bindparams
+        self.parent = parent
+        self.froms = {}
+        self.wheres = {}
+        self.strings = {}
+        
+    def get_from_text(self, obj):
+        return self.froms[obj]
+
+    def get_str(self, obj):
+        return self.strings[obj]
+
+    def get_whereclause(self, obj):
+        return self.wheres.get(obj, None)
+        
+    def get_params(self, **params):
+        d = {}
+        for key, value in params.iteritems():
+            try:
+                b = self.binds[key]
+            except KeyError:
+                raise "No such bind param in statement '%s': %s" % (str(self), key)
+            d[b.key] = value
+
+        for b in self.binds.values():
+            if not d.has_key(b.key):
+                d[b.key] = b.value
+
+        return d
+        
+    def visit_column(self, column):
+        if column.table.name is None:
+            self.strings[column] = column.name
+        else:
+            self.strings[column] = "%s.%s" % (column.table.name, column.name)
+
+    def visit_fromclause(self, fromclause):
+        self.froms[fromclause] = fromclause.from_name
+
+    def visit_textclause(self, textclause):
+        if textclause.parens and len(textclause.text):
+            self.strings[textclause] = "(" + textclause.text + ")"
+        else:
+            self.strings[textclause] = textclause.text
+       
+    def visit_compound(self, compound):
+        if compound.operator is None:
+            sep = " "
+        else:
+            sep = " " + compound.operator + " "
+            
+        if compound.parens:
+            self.strings[compound] = "(" + string.join([self.get_str(c) for c in compound.clauses], sep) + ")"
+        else:
+            self.strings[compound] = string.join([self.get_str(c) for c in compound.clauses], sep)
+
+    def visit_clauselist(self, list):
+        self.strings[list] = string.join([self.get_str(c) for c in list.clauses], ', ')
+        
+    def visit_binary(self, binary):
+        
+        if binary.parens:
+           self.strings[binary] = "(" + self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right) + ")"
+        else:
+            self.strings[binary] = self.get_str(binary.left) + " " + str(binary.operator) + " " + self.get_str(binary.right)
+        
+    def visit_bindparam(self, bindparam):
+        self.binds[bindparam.shortname] = bindparam
+        
+        count = 1
+        key = bindparam.key
+        
+        while self.binds.setdefault(key, bindparam) is not bindparam:
+            key = "%s_%d" % (bindparam.key, count)
+            count += 1
+            
+        self.strings[bindparam] = ":" + key
+
+    def visit_alias(self, alias):
+        self.froms[alias] = self.get_from_text(alias.selectable) + " " + alias.name
+
+    def visit_select(self, select):
+        inner_columns = []
+
+        for c in select._raw_columns:
+            for co in c.columns:
+                inner_columns.append(co)
+
+        if select.use_labels:
+            collist = string.join(["%s AS %s" % (c.fullname, c.label) for c in inner_columns], ', ')
+        else:
+            collist = string.join([c.fullname for c in inner_columns], ', ')
+
+        text = "SELECT " + collist + " FROM "
+        
+        whereclause = select.whereclause
+        
+        froms = []
+        for f in select.froms.values():
+
+            # special thingy used by oracle to redefine a join
+            w = self.get_whereclause(f)
+            if w is not None:
+                # TODO: move this more into the oracle module
+                whereclause = sql.and_(w, whereclause)
+                self.visit_compound(whereclause)
+                
+            t = self.get_from_text(f)
+            if t is not None:
+                froms.append(t)
+
+        text += string.join(froms, ', ')                
+
+        if whereclause is not None:
+            t = self.get_str(whereclause)
+            if t:
+                text += " WHERE " + t
+
+        for tup in select._clauses:
+            text += " " + tup[0] + " " + self.get_str(tup[1])
+
+        self.strings[select] = text
+        self.froms[select] = "(" + text + ")"
+
+
+    def visit_table(self, table):
+        self.froms[table] = table.name
+        
+    def visit_join(self, join):
+        if join.isouter:
+            self.froms[join] = ("(" + self.get_from_text(join.left) + " LEFT OUTER JOIN " + self.get_from_text(join.right) + 
+            " ON " + self.get_str(join.onclause) + ")")
+        else:
+            self.froms[join] = ("(" + self.get_from_text(join.left) + " JOIN " + self.get_from_text(join.right) + 
+            " ON " + self.get_str(join.onclause) + ")")
+
+    def visit_insert(self, insert_stmt):
+        colparams = insert_stmt.get_colparams(self.bindparams)
+
+        for c in colparams:
+            b = c[1]
+            self.binds[b.key] = b
+            self.binds[b.shortname] = b
+            
+        text = ("INSERT INTO " + insert_stmt.table.name + " (" + string.join([c[0].name for c in colparams], ', ') + ")" +
+         " VALUES (" + string.join([":" + c[1].key for c in colparams], ', ') + ")")
+         
+        self.strings[insert_stmt] = text
+
+    def visit_update(self, update_stmt):
+        colparams = update_stmt.get_colparams(self.bindparams)
+        
+        for c in colparams:
+            b = c[1]
+            self.binds[b.key] = b
+            self.binds[b.shortname] = b
+            
+        text = "UPDATE " + update_stmt.table.name + " SET " + string.join(["%s=:%s" % (c[0].name, c[1].key) for c in colparams], ', ')
+        
+        if update_stmt.whereclause:
+            text += " WHERE " + self.get_str(update_stmt.whereclause)
+         
+        self.strings[update_stmt] = text
+        
+    def __str__(self):
+        return self.get_str(self.parent)
+
+
+    
+class ANSISQLTableImpl(sql.TableImpl):
+    """Selectable implementation that gets attached to a schema.Table object."""
+    
+    def __init__(self, table):
+        sql.TableImpl.__init__(self)
+        self.table = table
+        self.id = self.table.name
+        
+    def get_from_text(self):
+        return self.table.name
+
+class ANSISchemaGenerator(sqlalchemy.engine.SchemaIterator):
+
+    def visit_table(self, table):
+        self.append("\nCREATE TABLE " + table.name + "(")
+        
+        separator = "\n"
+        
+        for column in table.columns:
+            self.append(separator)
+            separator = ", \n"
+            self.append("\t" + column._get_specification())
+            
+        self.append("\n)\n\n")
+        self.execute()
+
+    def visit_column(self, column):
+        pass
+    
+class ANSISchemaDropper(sqlalchemy.engine.SchemaIterator):
+    def visit_table(self, table):
+        self.append("\nDROP TABLE " + table.name)
+        self.execute()
+
+
diff --git a/lib/sqlalchemy/databases/sqlite.py b/lib/sqlalchemy/databases/sqlite.py
new file mode 100644 (file)
index 0000000..2208923
--- /dev/null
@@ -0,0 +1,53 @@
+import sys, StringIO, string, types
+
+import sqlalchemy.sql as sql
+import sqlalchemy.engine as engine
+import sqlalchemy.schema as schema
+import sqlalchemy.ansisql as ansisql
+from sqlalchemy.ansisql import *
+
+from pysqlite2 import dbapi2 as sqlite
+        
+colspecs = {        
+    schema.INT : "INTEGER",
+    schema.CHAR : "CHAR(%(length)s)",
+    schema.VARCHAR : "VARCHAR(%(length)s)",
+    schema.TEXT : "TEXT",
+    schema.FLOAT : "NUMERIC(%(precision)s, %(length)s)",
+    schema.DECIMAL : "NUMERIC(%(precision)s, %(length)s)",
+    schema.TIMESTAMP : "TIMESTAMP",
+    schema.DATETIME : "TIMESTAMP",
+    schema.CLOB : "TEXT",
+    schema.BLOB : "BLOB",
+    schema.BOOLEAN : "BOOLEAN",
+}
+
+def engine(filename, **params):
+    return SQLiteSQLEngine(filename, **params)
+    
+class SQLiteSQLEngine(ansisql.ANSISQLEngine):
+    def __init__(self, filename, **params):
+        self.filename = filename
+        ansisql.ANSISQLEngine.__init__(self, **params)
+    
+    def connect_args(self):
+        return ([self.filename], {})
+        
+    def dbapi(self):
+        return sqlite
+        
+    def columnimpl(self, column):
+        return SQLiteColumnImpl(column)
+
+class SQLiteColumnImpl(sql.ColumnSelectable):
+    def _get_specification(self):
+        coltype = self.column.type
+        if type(coltype) == types.ClassType:
+            key = coltype
+        else:
+            key = coltype.__class__
+
+        return self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', None), 'length' : getattr(coltype, 'length', None)}
+
+    
+    
diff --git a/lib/sqlalchemy/engine.py b/lib/sqlalchemy/engine.py
new file mode 100644 (file)
index 0000000..811fa94
--- /dev/null
@@ -0,0 +1,117 @@
+"""builds upon the schema and sql packages to provide a central object for tying schema objects
+and sql constructs to database-specific query compilation and execution"""
+
+import sqlalchemy.schema as schema
+import sqlalchemy.pool
+import sqlalchemy.util as util
+import sqlalchemy.sql as sql
+import StringIO
+
+class SchemaIterator(schema.SchemaVisitor):
+    """a visitor that can gather text into a buffer and execute the contents of the buffer."""
+    
+    def __init__(self, sqlproxy, **params):
+        self.sqlproxy = sqlproxy
+        self.buffer = StringIO.StringIO()
+
+    def run(self):
+        raise NotImplementedError()
+
+    def append(self, s):
+        self.buffer.write(s)
+        
+    def execute(self):
+        try:
+            return self.sqlproxy(self.buffer.getvalue())
+        finally:
+            self.buffer.truncate(0)
+
+class SQLEngine(schema.SchemaEngine):
+    """base class for a series of database-specific engines.  serves as an abstract factory for
+    implementation objects as well as database connections, transactions, SQL generators, etc."""
+    
+    def __init__(self, pool = None, echo = False, **params):
+        # get a handle on the connection pool via the connect arguments
+        # this insures the SQLEngine instance integrates with the pool referenced
+        # by direct usage of pool.manager(<module>).connect(*args, **params)
+        (cargs, cparams) = self.connect_args()
+        self._pool = sqlalchemy.pool.manage(self.dbapi()).get_pool(*cargs, **cparams)
+        self._echo = echo
+        self.context = util.ThreadLocal()
+        
+    def schemagenerator(self, proxy, **params):
+        raise NotImplementedError()
+
+    def schemadropper(self, proxy, **params):
+        raise NotImplementedError()
+        
+    def columnimpl(self, column):
+        return sql.ColumnSelectable(column)
+
+    def connect_args(self):
+        raise NotImplementedError()
+        
+    def dbapi(self):
+        raise NotImplementedError()
+
+    def compile(self, statement):
+        raise NotImplementedError()
+
+    def proxy(self):
+        return lambda s, p = None: self.execute(s, p)
+        
+    def connection(self):
+        return self._pool.connect()
+
+    def transaction(self, func):
+        self.begin()
+        try:
+            func()
+        except:
+            self.rollback()
+            raise
+        self.commit()
+            
+    def begin(self):
+        if getattr(self.context, 'transaction', None) is None:
+            conn = self.connection()
+            self.context.transaction = conn
+            self.context.tcount = 1
+        else:
+            self.context.tcount += 1
+            
+    def rollback(self):
+        if self.context.transaction is not None:
+            self.context.transaction.rollback()
+            self.context.transaction = None
+            self.context.tcount = None
+            
+    def commit(self):
+        if self.context.transaction is not None:
+            count = self.context.tcount - 1
+            self.context.tcount = count
+            if count == 0:
+                self.context.transaction.commit()
+                self.context.transaction = None
+                self.context.tcount = None
+                
+    def execute(self, statement, parameters, connection = None, **params):
+        if parameters is None:
+            parameters = {}
+        
+        if self._echo:
+            self.log(statement)
+            self.log(repr(parameters))
+            
+        if connection is None:
+            poolconn = self.connection()
+            c = poolconn.cursor()
+            c.execute(statement, parameters)
+            return c
+        else:
+            c = connection.cursor()
+            c.execute(statement, parameters)
+            return c
+
+    def log(self, msg):
+        print msg