]> git.ipfire.org Git - thirdparty/sqlalchemy/sqlalchemy.git/commitdiff
(no commit message)
authorMike Bayer <mike_mp@zzzcomputing.com>
Thu, 22 Sep 2005 04:36:05 +0000 (04:36 +0000)
committerMike Bayer <mike_mp@zzzcomputing.com>
Thu, 22 Sep 2005 04:36:05 +0000 (04:36 +0000)
lib/sqlalchemy/databases/sqlite.py

index cef18f1e981682e0bacf5e99d38594d22735e93f..9bb84645155576d09934938b6ab95e3a1f27e00b 100644 (file)
@@ -27,29 +27,60 @@ from sqlalchemy.ansisql import *
 
 from pysqlite2 import dbapi2 as sqlite
 
+class SLNumeric(sqltypes.Numeric):
+    def get_col_spec(self):
+        return "NUMERIC(%(precision)s, %(length)s)" % {'precision': self.precision, 'length' : self.length}
+class SLInteger(sqltypes.Integer):
+    def get_col_spec(self):
+        return "INTEGER"
+class SLDateTime(sqltypes.DateTime):
+    def get_col_spec(self):
+        return "TIMESTAMP"
+class SLText(sqltypes.TEXT):
+    def get_col_spec(self):
+        return "TEXT"
+class SLString(sqltypes.String):
+    def get_col_spec(self):
+        return "VARCHAR(%(length)s)" % {'length' : self.length}
+class SLChar(sqltypes.CHAR):
+    def get_col_spec(self):
+        return "CHAR(%(length)s)" % {'length' : self.length}
+class SLBinary(sqltypes.Binary):
+    def get_col_spec(self):
+        return "BLOB"
+class SLBoolean(sqltypes.Boolean):
+    def get_col_spec(self):
+        return "BOOLEAN"
+        
 colspecs = {
-    sqltypes.INT : "INTEGER",
-    sqltypes.CHAR : "CHAR(%(length)s)",
-    sqltypes.VARCHAR : "VARCHAR(%(length)s)",
-    sqltypes.TEXT : "TEXT",
-    sqltypes.Numeric : "NUMERIC(%(precision)s, %(length)s)",
-    sqltypes.FLOAT : "NUMERIC(%(precision)s, %(length)s)",
-    sqltypes.DECIMAL : "NUMERIC(%(precision)s, %(length)s)",
-    sqltypes.TIMESTAMP : "TIMESTAMP",
-    sqltypes.DATETIME : "TIMESTAMP",
-    sqltypes.CLOB : "TEXT",
-    sqltypes.BLOB : "BLOB",
-    sqltypes.BOOLEAN : "BOOLEAN",
+    sqltypes.INT : SLInteger,
+    sqltypes.CHAR : SLChar,
+    sqltypes.VARCHAR : SLString,
+    sqltypes.TEXT : SLText,
+    sqltypes.Numeric : SLNumeric,
+    sqltypes.TIMESTAMP : SLDateTime,
+    sqltypes.DATETIME : SLDateTime,
+    sqltypes.CLOB : SLText,
+    sqltypes.BLOB : SLBinary,
+    sqltypes.BOOLEAN : SLBoolean,
+    sqltypes.FLOAT : SLNumeric,
+    sqltypes.DECIMAL : SLNumeric,
 }
 
+def type_descriptor(typeobj):
+    try:
+        return typeobj.typeself.adapt(colspecs[typeobj.typeclass])
+    except KeyError:
+        return typeobj.typeself.adapt(typeobj.typeclass)
+
 pragma_names = {
-    'INTEGER' : sqltypes.INT,
-    'VARCHAR' : sqltypes.VARCHAR,
-    'CHAR' : sqltypes.CHAR,
-    'TEXT' : sqltypes.TEXT,
-    'NUMERIC' : sqltypes.FLOAT,
-    'TIMESTAMP' : sqltypes.TIMESTAMP,
-    'BLOB' : sqltypes.BLOB,
+    'INTEGER' : SLInteger,
+    'VARCHAR' : SLString,
+    'CHAR' : SLChar,
+    'TEXT' : SLText,
+    'NUMERIC' : SLNumeric,
+    'TIMESTAMP' : SLDateTime,
+    'BLOB' : SLBinary,
 }
 
 def engine(filename, opts, **params):
@@ -66,6 +97,9 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
         if getattr(compiled, "isinsert", False):
             self.context.last_inserted_ids = [cursor.lastrowid]
 
+    def type_descriptor(self, typeobj):
+        return type_descriptor(typeobj)
+        
     def last_inserted_ids(self):
         return self.context.last_inserted_ids
 
@@ -89,18 +123,18 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
             row = c.fetchone()
             if row is None:
                 break
-            print "row! " + repr(row)
+            #print "row! " + repr(row)
             (name, type, nullable, primary_key) = (row[1], row[2].upper(), not row[3], row[5])
             
             match = re.match(r'(\w+)(\(.*?\))?', type)
             coltype = match.group(1)
             args = match.group(2)
             
-            print "coltype: " + repr(coltype) + " args: " + repr(args)
+            #print "coltype: " + repr(coltype) + " args: " + repr(args)
             coltype = pragma_names[coltype]
             if args is not None:
                 args = re.findall(r'(\d+)', args)
-                print "args! " +repr(args)
+                #print "args! " +repr(args)
                 coltype = coltype(*args)
             table.append_item(schema.Column(name, coltype, primary_key = primary_key, nullable = nullable))
         c = self.execute("PRAGMA foreign_key_list(" + table.name + ")", {})
@@ -109,7 +143,7 @@ class SQLiteSQLEngine(ansisql.ANSISQLEngine):
             if row is None:
                 break
             (tablename, localcol, remotecol) = (row[2], row[3], row[4])
-            print "row! " + repr(row)
+            #print "row! " + repr(row)
             remotetable = Table(tablename, self, autoload = True)
             table.c[localcol].foreign_key = schema.ForeignKey(remotetable.c[remotecol])
             
@@ -119,14 +153,7 @@ class SQLiteCompiler(ansisql.ANSICompiler):
 
 class SQLiteColumnImpl(sql.ColumnSelectable):
     def get_specification(self):
-        coltype = self.column.type
-        if isinstance(coltype, type):
-            key = coltype
-        else:
-            key = coltype.__class__
-            
-
-        colspec = self.name + " " + colspecs[key] % {'precision': getattr(coltype, 'precision', 10), 'length' : getattr(coltype, 'length', 10)}
+        colspec = self.name + " " + self.column.type.get_col_spec()
         if not self.column.nullable:
             colspec += " NOT NULL"
         if self.column.primary_key: