While recently trying out the Flask web framework for Python, I wanted to access my MySQL database. At work I have been using Entity Framework, and I have gotten quite used to having a good database abstraction that generates SQL programmatically. Such frameworks exist in Python, but I thought it would be interesting to try writing one. This is a great example of getting carried away with a seemingly simple task.
I aimed for these things:
- Tables should be represented as objects, with each instance of the object representing a row
- These objects should be able to generate their own insert, select, and update queries
- Querying the database should be accomplished by logical predicates, not by strings
- Update queries should be optimized to only update those fields which have changed
- The database objects should have support for “immutable” fields that are generated by the database
I also wanted to support relations between tables via foreign keys, but I have stopped there for now. I have a structure outlined, but it isn't necessary at this point, since all I wanted was a database abstraction for my simple Flask project. I will probably implement it later.
This can be found as a gist here:
Before going into the code, here is an example of what this abstraction can do as it stands. It directly uses the DbObject and DbQuery-inheriting objects which are shown further down in this post.
```python
from db import *
import hashlib

def salt_password(user, unsalted):
    if user is None:
        return unsalted
    m = hashlib.sha512()
    m.update(user.username)
    m.update(unsalted)
    return m.hexdigest()

class User(DbObject):
    dbo_tablename = "users"
    primary_key = IntColumn("id", allow_none=True, mutable=False)
    username = StringColumn("username", "")
    password = PasswordColumn("password", salt_password, "")
    display_name = StringColumn("display_name", "")

    def __init__(self, **kwargs):
        DbObject.__init__(self, **kwargs)

    @classmethod
    def load(cls, cur, username):
        # select() returns the query and the column set used to build predicates
        query, columns = cls.select('u')
        query.where(columns.username == username)
        result = query.execute(cur)
        if len(result) == 0:
            return None
        return result[0]

    def match_password(self, password):
        salted = salt_password(self, password)
        return salted == self.password

# assume there is a function get_db defined which returns a PEP-249
# database connection object
def main():
    db = get_db()
    cur = db.cursor()

    user = User.load(cur, "some username")
    user.password = "a new password!"
    user.save(cur)  # issues an UPDATE for the changed password only
    db.commit()

    new_user = User(username="someone", display_name="Their name")
    new_user.password = "A password that will be hashed"
    new_user.save(cur)  # no primary key yet, so this issues an INSERT
    db.commit()
    print new_user.primary_key  # this will now have a database-assigned id
```
This example first loads a user using a DbSelectQuery. The user is then modified and the DbObject-level function save() is used to save it. Next, a new user is created and saved using the same function. After saving, the primary key will have been populated and will be printed.
Change Tracking Columns
I started out with columns. I needed columns that track changes and have a mapping to an SQL column name. I came up with the following:
```python
class ColumnSet(object):
    """
    Object which is updated by ColumnInstances to inform changes
    """
    def __init__(self):
        self.__columns = {}  # column instances keyed by SQL column name
        i_dict = type(self).__dict__
        for attr in i_dict:
            obj = i_dict[attr]
            if isinstance(obj, Column):
                # we get an instance of this column
                self.__columns[obj.name] = ColumnInstance(obj, self)

    @property
    def mutated(self):
        """
        Returns the mutated columns for this tracker.
        """
        output = []
        for name in self.__columns:
            column = self.get_column(name)
            if column.mutated:
                output.append(column)
        return output

    def get_column(self, name):
        return self.__columns[name]


class ColumnInstance(object):
    """
    Per-instance column data. This is used in ColumnSet objects to hold data
    specific to that particular instance
    """
    def __init__(self, column, owner):
        """
        column: Column object this is created for
        owner: ColumnSet instance this value belongs to
        """
        self.__column = column
        self.__owner = owner
        self.update(column.default)

    def update(self, value):
        """
        Updates the value for this instance, resetting the mutated flag
        """
        if value is None and not self.__column.allow_none:
            raise ValueError("'None' is invalid for column '" +
                             self.__column.name + "'")
        if self.__column.validate(value):
            self.__value = value
            self.__origvalue = value
        else:
            raise ValueError("'" + str(value) + "' is not valid for column '" +
                             self.__column.name + "'")

    @property
    def column(self):
        return self.__column

    @property
    def owner(self):
        return self.__owner

    @property
    def mutated(self):
        return self.__value != self.__origvalue

    @property
    def value(self):
        return self.__value

    @value.setter
    def value(self, value):
        if value is None and not self.__column.allow_none:
            raise ValueError("'None' is invalid for column '" +
                             self.__column.name + "'")
        if not self.__column.mutable:
            raise AttributeError("Column '" + self.__column.name +
                                 "' is not mutable")
        if self.__column.validate(value):
            self.__value = value
        else:
            raise ValueError("'" + str(value) + "' is not valid for column '" +
                             self.__column.name + "'")


class Column(object):
    """
    Column descriptor for a column
    """
    def __init__(self, name, default=None, allow_none=False, mutable=True):
        """
        Initializes a column
        name: Name of the column this maps to
        default: Default value
        allow_none: Whether None (database NULL) values are allowed
        mutable: Whether this can be mutated by a setter
        """
        self.__name = name
        self.__allow_none = allow_none
        self.__mutable = mutable
        self.__default = default

    def validate(self, value):
        """
        In a child class, this will validate values being set
        """
        raise NotImplementedError

    @property
    def name(self):
        return self.__name

    @property
    def allow_none(self):
        return self.__allow_none

    @property
    def mutable(self):
        return self.__mutable

    @property
    def default(self):
        return self.__default

    def __get__(self, owner, ownertype=None):
        """
        Gets the value for this column for the passed owner
        """
        if owner is None:
            return self  # accessed on the class: return the descriptor
        if not isinstance(owner, ColumnSet):
            raise TypeError("Columns are only allowed on ColumnSets")
        return owner.get_column(self.name).value

    def __set__(self, owner, value):
        """
        Sets the value for this column for the passed owner
        """
        if not isinstance(owner, ColumnSet):
            raise TypeError("Columns are only allowed on ColumnSets")
        owner.get_column(self.name).value = value


class StringColumn(Column):
    def validate(self, value):
        if value is None and self.allow_none:
            return True
        return isinstance(value, basestring)


class IntColumn(Column):
    def validate(self, value):
        if value is None and self.allow_none:
            return True
        return isinstance(value, (int, long))


class PasswordColumn(Column):
    def __init__(self, name, salt_function, default=None, allow_none=False,
                 mutable=True):
        """
        Create a new password column which uses the specified salt function
        salt_function: a function(owner, value) which returns the salted string
        """
        Column.__init__(self, name, default, allow_none, mutable)
        self.__salt_function = salt_function

    def validate(self, value):
        return True

    def __set__(self, owner, value):
        # salt before storing, so the plain-text password is never kept
        salted = self.__salt_function(owner, value)
        super(PasswordColumn, self).__set__(owner, salted)
```
The Column class describes a column and is implemented as a descriptor. Each ColumnSet instance holds one ColumnInstance per column, which stores the per-object state: the current value and whether it has been mutated. Each column type has a validation function to help screen invalid data. When a ColumnSet is instantiated, it scans its class's own __dict__ for columns (so columns inherited from a base class are not picked up) and creates its ColumnInstances at that moment.
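The same descriptor-plus-per-instance-state pattern can be sketched compactly in Python 3. This is not the post's code: TrackedColumn, Tracked, and mark_clean are hypothetical names, the per-instance state lives in the instance's __dict__ rather than in separate ColumnInstance objects, validation is left out, and __set_name__ requires Python 3.6 or later.

```python
class TrackedColumn:
    """Descriptor mapping an attribute to an SQL column name and tracking changes."""

    def __init__(self, sql_name, default=None):
        self.sql_name = sql_name
        self.default = default

    def __set_name__(self, owner, attr):
        # Called once when the owning class body is executed (Python 3.6+).
        self.attr = attr

    def state(self, instance):
        # Per-instance record of the current and original value, created lazily.
        key = "_col_" + self.attr
        if key not in instance.__dict__:
            instance.__dict__[key] = {"value": self.default, "orig": self.default}
        return instance.__dict__[key]

    def __get__(self, instance, owner=None):
        if instance is None:
            return self  # class-level access returns the descriptor itself
        return self.state(instance)["value"]

    def __set__(self, instance, value):
        self.state(instance)["value"] = value


class Tracked:
    """Base class that finds its TrackedColumn descriptors by scanning its class."""

    def columns(self):
        return [obj for obj in vars(type(self)).values()
                if isinstance(obj, TrackedColumn)]

    def mutated(self):
        # SQL names of the columns whose value no longer matches the original.
        return [c.sql_name for c in self.columns()
                if c.state(self)["value"] != c.state(self)["orig"]]

    def mark_clean(self):
        # Called after a successful save: the current values become the originals.
        for c in self.columns():
            s = c.state(self)
            s["orig"] = s["value"]


class Row(Tracked):
    name = TrackedColumn("name", "")
    age = TrackedColumn("age_years", 0)


row = Row()
row.age = 31
print(row.mutated())  # ['age_years']
row.mark_clean()
print(row.mutated())  # []
```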
Generation of SQL using logical predicates
The next thing I had to create was the database querying structure. I decided that rather than actually using the ColumnInstance or Column objects, I would use a go-between object that can be assigned a “prefix”. A common thing to do in SQL queries is to rename the tables in the query so that you can reference the same table multiple times or use different tables with the same column names. So, for example if I had a table called posts and I also had a table called users and they both shared a column called ‘last_update’, I could assign a prefix ‘p’ to the post columns and a prefix ‘u’ to the user columns so that the final column name would be ‘p.last_update’ and ‘u.last_update’ for posts and users respectively.
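To see why the prefixes matter, here is a small Python 3 sketch using the built-in sqlite3 module and an in-memory database. The user_id column, the qualify() helper, and the sample values are assumptions for illustration; the point is that SQLite rejects the unqualified last_update as ambiguous, while the aliased p.last_update and u.last_update work side by side.

```python
import sqlite3

def qualify(prefix, column):
    """Render a column under a table alias, e.g. ('p', 'last_update') -> 'p.last_update'."""
    return column if prefix is None else prefix + "." + column

conn = sqlite3.connect(":memory:")
cur = conn.cursor()
# Two tables that share a column name (user_id is a hypothetical foreign key).
cur.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, last_update TEXT)")
cur.execute("CREATE TABLE posts (id INTEGER PRIMARY KEY, user_id INTEGER, last_update TEXT)")
cur.execute("INSERT INTO users (id, last_update) VALUES (1, 'user edited')")
cur.execute("INSERT INTO posts (id, user_id, last_update) VALUES (10, 1, 'post edited')")

# Without an alias the column reference is ambiguous and SQLite rejects it.
try:
    cur.execute("SELECT last_update FROM posts p JOIN users u ON p.user_id = u.id")
    ambiguous = False
except sqlite3.OperationalError:
    ambiguous = True

# With prefixes, both columns can be selected in the same query.
sql = ("SELECT " + qualify("p", "last_update") + ", " + qualify("u", "last_update")
       + " FROM posts p JOIN users u ON "
       + qualify("p", "user_id") + " = " + qualify("u", "id"))
rows = cur.execute(sql).fetchall()
print(sql)
print(rows)  # [('post edited', 'user edited')]
```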
Another thing I wanted to do was avoid writing SQL by hand when constructing my queries. This is similar to the way LINQ works in C#: a predicate is specified and later translated into an SQL query or a series of in-memory operations, depending on the data source being queried. So, in Python one of my queries looks like this:
```python
class Table(ColumnSet):
    some_column = StringColumn("column_1", "")
    another = IntColumn("column_2", 0)

a_variable = 5
columns = DbQueryColumnSet(Table, 'x')  # columns with a prefix 'x'
query = DbQuery()  # this base class just builds a WHERE clause
# each comparison needs its own parentheses because & binds tighter than == and >
query.where((columns.some_column == "4") & (columns.another > a_variable))
print query.sql
```
This prints a tuple containing the SQL text " WHERE x.column_1=%s AND x.column_2>%s" and the argument list ["4", 5] to bind to its placeholders. So, how does this work? I used operator overloading to create DbQueryExpression objects. The code is as follows:
```python
class DbQueryExpression(object):
    """
    Query expression created from columns, literals, and operators
    """
    def __and__(self, other):
        return DbQueryConjunction(self, other)

    def __or__(self, other):
        return DbQueryDisjunction(self, other)

    def __str__(self):
        raise NotImplementedError

    @property
    def arguments(self):
        raise NotImplementedError


class DbQueryConjunction(DbQueryExpression):
    """
    Query expression joining together a left and right expression with an
    AND statement
    """
    def __init__(self, l, r):
        DbQueryExpression.__init__(self)
        self.l = l
        self.r = r

    def __str__(self):
        return str(self.l) + " AND " + str(self.r)

    @property
    def arguments(self):
        return self.l.arguments + self.r.arguments


class DbQueryDisjunction(DbQueryExpression):
    """
    Query expression joining together a left and right expression with an
    OR statement
    """
    def __init__(self, l, r):
        DbQueryExpression.__init__(self)
        self.l = l
        self.r = r

    def __str__(self):
        # parenthesized because AND binds tighter than OR in SQL
        return "(" + str(self.l) + " OR " + str(self.r) + ")"

    @property
    def arguments(self):
        return self.l.arguments + self.r.arguments


class DbQueryColumnComparison(DbQueryExpression):
    """
    Query expression comparing a combination of a column and/or a value
    """
    def __init__(self, l, op, r):
        DbQueryExpression.__init__(self)
        self.l = l
        self.op = op
        self.r = r

    def __str__(self):
        output = ""
        if isinstance(self.l, DbQueryColumn):
            prefix = self.l.prefix
            if prefix is not None:
                output += prefix + "."
            output += self.l.name
        elif self.l is None:
            output += "NULL"
        else:
            output += "%s"  # literal values become placeholders
        output += self.op
        if isinstance(self.r, DbQueryColumn):
            prefix = self.r.prefix
            if prefix is not None:
                output += prefix + "."
            output += self.r.name
        elif self.r is None:
            output += "NULL"
        else:
            output += "%s"
        return output

    @property
    def arguments(self):
        output = []
        if not isinstance(self.l, DbQueryColumn) and self.l is not None:
            output.append(self.l)
        if not isinstance(self.r, DbQueryColumn) and self.r is not None:
            output.append(self.r)
        return output


class DbQueryColumnSet(object):
    """
    Represents a set of columns attached to a specific DbObject type. This
    object dynamically builds itself based on a passed type. The columns
    attached to this set may be used in DbQueries
    """
    def __init__(self, dbo_type, prefix):
        d = dbo_type.__dict__
        self.__columns = {}
        for attr in d:
            obj = d[attr]
            if isinstance(obj, Column):
                column = DbQueryColumn(dbo_type, prefix, obj.name)
                setattr(self, attr, column)  # access by attribute name
                self.__columns[obj.name] = column  # or by SQL column name

    def __len__(self):
        return len(self.__columns)

    def __getitem__(self, key):
        return self.__columns[key]

    def __iter__(self):
        return iter(self.__columns)


class DbQueryColumn(object):
    """
    Represents a Column object used in a DbQuery
    """
    def __init__(self, dbo_type, prefix, column_name):
        self.dbo_type = dbo_type
        self.name = column_name
        self.prefix = prefix

    def __lt__(self, other):
        return DbQueryColumnComparison(self, "<", other)

    def __le__(self, other):
        return DbQueryColumnComparison(self, "<=", other)

    def __eq__(self, other):
        op = "="
        if other is None:
            op = " IS "
        return DbQueryColumnComparison(self, op, other)

    def __ne__(self, other):
        op = "!="
        if other is None:
            op = " IS NOT "
        return DbQueryColumnComparison(self, op, other)

    def __gt__(self, other):
        return DbQueryColumnComparison(self, ">", other)

    def __ge__(self, other):
        return DbQueryColumnComparison(self, ">=", other)
```
The __str__ method and the arguments property recursively generate the SQL text (using the column prefixes) and the argument list, respectively. As can be seen, this supports parameterized queries. To be honest, this part was the most fun, since I was surprised how easy it was to build predicate expressions with so few classes. One thing I didn't like, however, is that Python's boolean and/or operators cannot be overloaded. For that reason I had to use the bitwise & and | operators instead, so the expressions don't read quite naturally. And because & and | bind more tightly than comparison operators, each comparison must be wrapped in its own parentheses.
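Here is a compact Python 3 sketch of the same technique, using hypothetical class names (Col, Compare, BoolOp) instead of the post's and returning the SQL text and arguments together from one render() method. It renders every AND/OR node in parentheses, since SQL's AND binds more tightly than OR and an unparenthesized (a | b) & c would be read as a OR (b AND c). It also turns comparisons against None into IS NULL / IS NOT NULL and demonstrates the precedence pitfall of the bitwise operators.

```python
class Expr:
    """Base node: & and | build AND/OR nodes, since `and`/`or` cannot be overloaded."""
    def __and__(self, other):
        return BoolOp("AND", self, other)

    def __or__(self, other):
        return BoolOp("OR", self, other)


class BoolOp(Expr):
    def __init__(self, op, l, r):
        self.op, self.l, self.r = op, l, r

    def render(self):
        lsql, largs = self.l.render()
        rsql, rargs = self.r.render()
        # Parentheses preserve the grouping written in Python.
        return "(" + lsql + " " + self.op + " " + rsql + ")", largs + rargs


class Compare(Expr):
    def __init__(self, col, op, value):
        self.col, self.op, self.value = col, op, value

    def render(self):
        if self.value is None:
            if self.op not in ("=", "!="):
                raise ValueError("NULL can only be compared with == or !=")
            return self.col.qualified() + (" IS NULL" if self.op == "=" else " IS NOT NULL"), []
        # Literal values become placeholders; the value goes into the argument list.
        return self.col.qualified() + " " + self.op + " %s", [self.value]


class Col:
    def __init__(self, prefix, name):
        self.prefix, self.name = prefix, name

    def qualified(self):
        return self.prefix + "." + self.name

    def __eq__(self, other):
        return Compare(self, "=", other)

    def __ne__(self, other):
        return Compare(self, "!=", other)

    def __lt__(self, other):
        return Compare(self, "<", other)

    def __gt__(self, other):
        return Compare(self, ">", other)


a, b = Col("x", "column_1"), Col("x", "column_2")
sql, args = (((a == "4") | (b > 5)) & (b != None)).render()
print(sql, args)
# ((x.column_1 = %s OR x.column_2 > %s) AND x.column_2 IS NOT NULL) ['4', 5]

# & binds tighter than ==, so without parentheses Python evaluates "4" & b first.
try:
    a == "4" & b > 5
    needs_parentheses = False
except TypeError:
    needs_parentheses = True
```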
This DbQueryExpression is fed into my DbQuery object, which does the actual translation to SQL. In the example above, I just passed a logical expression into the where() function. That argument really was a DbQueryExpression, since my overloaded comparison operators return DbQueryExpression objects. The DbQueryColumnSet object is a dynamically generated object containing the go-between column objects, built from a DbObject type (or any class with Column attributes). We will discuss the DbObject a little further down.
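The dynamic construction can be illustrated with types.SimpleNamespace in Python 3. This is a simplified stand-in: plain strings replace DbQueryColumn objects, and Column, User, and query_columns() are hypothetical. It only shows how attribute names on the generated object (primary_key) map to prefixed SQL names (u.id).

```python
from types import SimpleNamespace

class Column:
    """Stand-in for the post's Column: it only knows its SQL name."""
    def __init__(self, name):
        self.name = name

class User:
    primary_key = Column("id")  # the attribute name differs from the SQL name
    username = Column("username")

def query_columns(cls, prefix):
    """Build an object with one prefixed column reference per Column attribute."""
    refs = {attr: prefix + "." + obj.name
            for attr, obj in vars(cls).items() if isinstance(obj, Column)}
    return SimpleNamespace(**refs)

u = query_columns(User, "u")
print(u.primary_key, u.username)  # u.id u.username
```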
The DbQuery objects are implemented as follows:
```python
class DbQueryError(Exception):
    """
    Raised when there is an error constructing a query
    """
    def __init__(self, msg):
        self.message = msg

    def __str__(self):
        return self.message


class DbQuery(object):
    """
    Represents a base SQL Query to a database based upon some DbObjects
    All of the methods implemented here are valid on select, update, and
    delete statements.
    """
    def __init__(self, execute_filter=None):
        """
        execute_filter: Function called with (query, cursor) when the
        DbQuery is executed
        """
        self.__where = []
        self.__limit = None
        self.__execute_filter = execute_filter

    def where(self, expression):
        """Specify an expression to append to the WHERE clause"""
        self.__where.append(expression)

    def limit(self, value=None):
        """Specify the limit to the query"""
        self.__limit = value

    def _where_clause(self):
        """Returns the WHERE clause (all expressions ANDed) and its arguments"""
        if len(self.__where) == 0:
            return "", []
        where = self.__where[0]
        for clause in self.__where[1:]:
            where = where & clause
        return " WHERE " + str(where), where.arguments

    def _limit_clause(self):
        """Returns the LIMIT clause, or an empty string if there is no limit"""
        if self.__limit is None:
            return ""
        return " LIMIT " + str(self.__limit)

    @property
    def sql(self):
        query, args = self._where_clause()
        return query + self._limit_clause(), args

    def execute(self, cur):
        """
        Executes this query on the passed cursor and returns either the result
        of the filter function or the cursor if there is no filter function.
        """
        query, args = self.sql
        cur.execute(query, args)
        if self.__execute_filter:
            return self.__execute_filter(self, cur)
        return cur


class DbSelectQuery(DbQuery):
    """
    Creates a select query to a database based upon DbObjects
    """
    def __init__(self, execute_filter=None):
        DbQuery.__init__(self, execute_filter)
        self.__select = []
        self.__froms = []
        self.__joins = []
        self.__orderby = []

    def select(self, *columns):
        """Specify one or more columns to select"""
        self.__select += columns

    def from_table(self, dbo_type, prefix):
        """Specify a table to select from"""
        self.__froms.append((dbo_type, prefix))

    def join(self, dbo_type, prefix, on):
        """Specify a table to join to"""
        self.__joins.append((dbo_type, prefix, on))

    def orderby(self, *columns):
        """Specify one or more columns to order by"""
        self.__orderby += columns

    @property
    def sql(self):
        if len(self.__select) == 0:
            raise DbQueryError("No selection in DbSelectQuery")
        if len(self.__froms) == 0:
            raise DbQueryError("No FROM clause in DbSelectQuery")
        query = "SELECT " + ",".join([col.prefix + "." + col.name
                                      for col in self.__select])
        # multiple tables are comma-separated in a single FROM clause
        query += " FROM " + ",".join([dbo_type.dbo_tablename + " " + prefix
                                      for dbo_type, prefix in self.__froms])
        args = []
        for dbo_type, prefix, on in self.__joins:
            query += (" JOIN " + dbo_type.dbo_tablename + " " + prefix +
                      " ON " + str(on))
            args += on.arguments  # JOIN arguments precede WHERE arguments
        where, where_args = self._where_clause()
        query += where
        args += where_args
        if len(self.__orderby) > 0:
            query += " ORDER BY " + ",".join([col.prefix + "." + col.name
                                              for col in self.__orderby])
        query += self._limit_clause()  # LIMIT must come after ORDER BY
        return query, args


class DbInsertQuery(DbQuery):
    """
    Creates an insert query to a database based upon DbObjects.
    This does not include any where or limit expressions
    """
    def __init__(self, dbo_type, prefix, execute_filter=None):
        DbQuery.__init__(self, execute_filter)
        self.table = (dbo_type, prefix)
        self.__values = []

    def value(self, column, value):
        self.__values.append((column, value))

    @property
    def sql(self):
        if len(self.__values) == 0:
            raise DbQueryError("No values in insert")
        tablename = self.table[0].dbo_tablename
        # only values whose column uses this table's prefix are inserted
        values = [val for val in self.__values if val[0].prefix == self.table[1]]
        query = "INSERT INTO {table} (".format(table=tablename)
        query += ",".join([column.name for column, value in values])
        query += ") VALUES ("
        query += ",".join(["%s"] * len(values))
        query += ")"
        return query, [value for column, value in values]


class DbUpdateQuery(DbQuery):
    """
    Creates an update query to a database based upon DbObjects
    """
    def __init__(self, dbo_type, prefix, execute_filter=None):
        """
        Initialize the update query
        dbo_type: table type to be updating
        prefix: Prefix the columns are known under
        """
        DbQuery.__init__(self, execute_filter)
        self.table = (dbo_type, prefix)
        self.__updates = []

    def update(self, left, right):
        self.__updates.append((left, right))

    @property
    def sql(self):
        if len(self.__updates) == 0:
            raise DbQueryError("No update in DbUpdateQuery")
        query = "UPDATE " + self.table[0].dbo_tablename + " " + self.table[1]
        args = []
        assignments = []
        for left, right in self.__updates:
            parts = []
            for side in (left, right):
                if isinstance(side, DbQueryColumn):
                    parts.append(side.prefix + "." + side.name)
                else:
                    parts.append("%s")
                    args.append(side)
            assignments.append(parts[0] + "=" + parts[1])
        query += " SET " + ",".join(assignments)  # assignments are comma-separated
        parent_query, parent_args = super(DbUpdateQuery, self).sql
        return query + parent_query, args + parent_args


class DbDeleteQuery(DbQuery):
    """
    Creates a delete query for a database based on a DbObject
    """
    def __init__(self, dbo_type, prefix, execute_filter=None):
        DbQuery.__init__(self, execute_filter)
        self.table = (dbo_type, prefix)

    @property
    def sql(self):
        query = "DELETE FROM " + self.table[0].dbo_tablename + " " + self.table[1]
        parent_query, parent_args = super(DbDeleteQuery, self).sql
        return query + parent_query, parent_args
```
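Each of the SELECT, INSERT, UPDATE, and DELETE query types inherits from a base DbQuery, which builds the WHERE and LIMIT clauses and handles execution. I decided to make the DbQuery object take a PEP 249-style cursor object and execute the query itself. My hope is that this makes the code a little more portable since, to my knowledge, I didn't use any MySQL-specific constructions in the queries. One caveat is the placeholder syntax: %s is the "format" paramstyle used by MySQLdb, while other PEP 249 drivers, such as Python's sqlite3 module, expect ? instead.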
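One way to deal with the placeholder difference is to check the driver's module-level paramstyle attribute, which PEP 249 requires, before executing. The Python 3 sketch below is an assumption rather than part of the post's code: adapt_placeholders() is a hypothetical helper that handles only the "format" and "qmark" styles, demonstrated against sqlite3.

```python
import sqlite3

def adapt_placeholders(query, paramstyle):
    """Rewrite the generator's %s placeholders for the driver's paramstyle.

    Assumes '%s' appears in the query only as a placeholder, which holds when
    every literal value is passed as an argument rather than inlined.
    """
    if paramstyle == "format":  # e.g. MySQLdb
        return query
    if paramstyle == "qmark":  # e.g. sqlite3
        return query.replace("%s", "?")
    raise NotImplementedError("paramstyle not handled in this sketch: " + paramstyle)


conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT)")
cur.execute("INSERT INTO users (username) VALUES ('someone')")

# The kind of (query, args) pair the DbQuery classes produce.
query, args = "SELECT x.id FROM users x WHERE x.username=%s", ["someone"]
cur.execute(adapt_placeholders(query, sqlite3.paramstyle), args)
print(sqlite3.paramstyle, cur.fetchall())  # qmark [(1,)]
```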
The different query types each implement methods corresponding to different parts of an SQL query: where(), limit(), orderby(), select(), from_table(), etc. Most of these take query objects: where() takes a DbQueryExpression, while select() and orderby() take DbQueryColumns. limit(), on the other hand, takes a value that is appended to the query as text. I could easily have made limit() take two integers as well, but I was rushing because I wanted to see if this would even work. A query is built by creating the query object for the desired query type and then calling its methods to add clauses to it.
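The two-integer limit() mentioned above could look like the following Python 3 sketch. limit_clause() is a hypothetical helper, not part of the post's code: it validates its inputs, emits LIMIT/OFFSET as bound parameters rather than raw text, and its output is appended after ORDER BY, which is where SQL expects it. The demo runs on sqlite3, so %s is swapped for ? before executing.

```python
import sqlite3

def limit_clause(count, offset=None):
    """Build a parameterized LIMIT/OFFSET clause from integers (hypothetical helper)."""
    if not isinstance(count, int) or (offset is not None and not isinstance(offset, int)):
        raise TypeError("limit and offset must be integers")
    sql, args = " LIMIT %s", [count]
    if offset is not None:
        sql += " OFFSET %s"
        args.append(offset)
    return sql, args

# ORDER BY comes first; the LIMIT clause is appended after it.
clause, args = limit_clause(2, offset=1)
query = "SELECT x.v FROM numbers x ORDER BY x.v" + clause

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE numbers (v INTEGER)")
conn.executemany("INSERT INTO numbers (v) VALUES (?)", [(n,) for n in range(1, 6)])
# sqlite3 uses ? placeholders, so swap them in for this demo.
rows = conn.execute(query.replace("%s", "?"), args).fetchall()
print(query, args, rows)  # rows are [(2,), (3,)]
```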
Executing the queries can cause a callback “filter” function to be called which takes in the query and the cursor as arguments. I use this function to create new objects from the data or to update an object. It could probably be used for more clever things as well, but those two cases were my original intent in creating it. If no filter is specified, then the cursor is returned.
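The filter mechanism is easy to try out in isolation. In the Python 3 sketch below, Query is a hypothetical stand-in for DbQuery (same execute() contract: call the filter with the query and cursor, or return the cursor), and rows_as_dicts() is an example filter that drains the cursor with fetchmany() the way DbObject.select() does. Unlike the post's filter, which relies on column order, it reads column names from the PEP 249 cursor.description attribute.

```python
import sqlite3

class Query:
    """Stand-in for DbQuery: runs the SQL, then passes (query, cursor) to a filter."""
    def __init__(self, sql, args, execute_filter=None):
        self.sql = (sql, args)
        self.execute_filter = execute_filter

    def execute(self, cur):
        cur.execute(*self.sql)
        if self.execute_filter:
            return self.execute_filter(self, cur)
        return cur  # no filter: hand back the cursor


def rows_as_dicts(query, cur):
    """Example filter: read the result in batches and build one dict per row."""
    names = [d[0] for d in cur.description]  # column names from the cursor
    output = []
    block = cur.fetchmany()
    while block:
        output.extend(dict(zip(names, row)) for row in block)
        block = cur.fetchmany()
    return output


conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT)")
cur.executemany("INSERT INTO users (username) VALUES (?)", [("alice",), ("bob",)])

users = Query("SELECT id, username FROM users ORDER BY id", [], rows_as_dicts).execute(cur)
print(users)  # [{'id': 1, 'username': 'alice'}, {'id': 2, 'username': 'bob'}]
```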
Table and row objects
At the highest level of this hierarchy is the DbObject. The DbObject definition actually represents a table in the database with a name and a single primary key column. Each instance represents a row. DbObjects also implement the methods for selecting records of their type and also updating themselves when they are changed. They inherit change tracking from the ColumnSet and use DbQueries to accomplish their querying goals. The code is as follows:
```python
class DbObject(ColumnSet):
    """
    A DbObject is a set of columns linked to a table in the database. Each
    instance corresponds to a row. The following class attributes must be set:
    dbo_tablename : string table name
    primary_key : Column for the primary key
    """
    def __init__(self, **cols):
        ColumnSet.__init__(self)
        for name in cols:
            c = self.get_column(name)
            c.update(cols[name])

    @classmethod
    def get_query_columns(cls, prefix):
        return DbQueryColumnSet(cls, prefix)

    @classmethod
    def select(cls, prefix):
        """
        Returns a DbSelectQuery set up for this DbObject, along with the
        DbQueryColumnSet used to build predicates for it
        """
        columns = cls.get_query_columns(prefix)

        def execute(query, cur):
            output = []
            block = cur.fetchmany()
            while len(block) > 0:
                for row in block:
                    values = {}
                    i = 0
                    for name in columns:
                        values[name] = row[i]
                        i += 1
                    output.append(cls(**values))
                block = cur.fetchmany()
            return output

        query = DbSelectQuery(execute)
        query.select(*[columns[name] for name in columns])
        query.from_table(cls, prefix)
        return query, columns

    def get_primary_key_name(self):
        return type(self).__dict__['primary_key'].name

    def save(self, cur):
        """
        Saves any changes to this object to the database
        """
        pk_name = self.get_primary_key_name()
        if self.primary_key is None:
            # we need to be inserted
            columns = self.get_query_columns('x')

            def execute(query, cur):
                self.get_column(pk_name).update(cur.lastrowid)
                selection = []
                for name in columns:
                    if name == pk_name:
                        continue  # the primary key came from lastrowid
                    column_instance = self.get_column(name)
                    if not column_instance.column.mutable:
                        selection.append(columns[name])
                if len(selection) != 0:
                    # select the other database-computed (immutable) values
                    def execute2(query, cur):
                        row = cur.fetchone()
                        for index, s in enumerate(selection):
                            self.get_column(s.name).update(row[index])
                        return True
                    query = DbSelectQuery(execute2)
                    query.select(*selection)
                    query.from_table(type(self), 'x')
                    query.where(columns[pk_name] == self.primary_key)
                    return query.execute(cur)
                return True

            query = DbInsertQuery(type(self), 'x', execute)
            for name in columns:
                column_instance = self.get_column(name)
                if not column_instance.column.mutable:
                    continue  # immutable columns are generated by the database
                query.value(columns[name], column_instance.value)
            return query.execute(cur)
        else:
            # we have been modified
            modified = self.mutated
            if len(modified) == 0:
                return True
            columns = self.get_query_columns('x')

            def execute(query, cur):
                for mod in modified:
                    mod.update(mod.value)  # resets the mutated flag
                return True

            query = DbUpdateQuery(type(self), 'x', execute)
            for mod in modified:
                query.update(columns[mod.column.name], mod.value)
            query.where(columns[pk_name] == self.primary_key)
            return query.execute(cur)
```
DbObjects require that inheriting classes define two class attributes: dbo_tablename and primary_key. dbo_tablename is a string giving the name of the table in the database, and primary_key is the Column used as the primary key.
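Nothing in the code above enforces these two attributes; a missing one only surfaces later as a KeyError or AttributeError. As a possible improvement (an assumption, not part of the post's code), Python 3.6+'s __init_subclass__ hook can check them at the moment a subclass is defined. DbObjectBase and required_attributes are hypothetical names.

```python
class DbObjectBase:
    """Hypothetical base class that checks, when a subclass is defined,
    that it provides the class attributes DbObject relies on."""
    required_attributes = ("dbo_tablename", "primary_key")

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        missing = [name for name in cls.required_attributes if not hasattr(cls, name)]
        if missing:
            raise TypeError(cls.__name__ + " is missing " + ", ".join(missing))


class Users(DbObjectBase):
    dbo_tablename = "users"
    primary_key = "id"  # a plain string stands in for a Column here


try:
    class Broken(DbObjectBase):
        dbo_tablename = "broken"
    error = None
except TypeError as exc:
    error = str(exc)

print(error)  # Broken is missing primary_key
```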
To select records from the database, call the select() class method. It returns a tuple of a DbSelectQuery, already set up to select every column of the table, and the DbQueryColumnSet used to build predicates. When that query is executed, it returns a list of instances of the DbObject class it was called on.
One flaw of this structure is that it treats a primary key of None as "not yet saved". In other words, it does not allow for null primary keys: when save() is called and the primary key hasn't been set, it generates a DbInsertQuery for the object instead of a DbUpdateQuery. Neither query includes every field. An insert writes only the mutable columns, and an update writes only the columns that have changed. Immutable fields are never written; after an insert, the primary key is taken from the cursor's lastrowid and any other immutable fields are selected back from the database.
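The insert-or-update logic can be tried out with a minimal Python 3 sketch against an in-memory SQLite database. The save() function here is hypothetical: it works on a plain dict plus a set of changed column names rather than on a DbObject, uses sqlite3's ? placeholders, and shares the limitation described above (a None primary key means "insert"). The created column stands in for an immutable, database-generated field; a constant DEFAULT is used so the result is predictable.

```python
import sqlite3

def save(cur, table, row, changed, pk="id", generated=()):
    """Insert row if row[pk] is None, otherwise update only the changed columns.

    Table and column names are interpolated directly, so they must come from
    trusted class definitions, never from user input; values are always bound.
    """
    if row[pk] is None:
        # INSERT skips the primary key and any database-generated columns.
        cols = [c for c in row if c != pk and c not in generated]
        cur.execute("INSERT INTO %s (%s) VALUES (%s)"
                    % (table, ",".join(cols), ",".join(["?"] * len(cols))),
                    [row[c] for c in cols])
        row[pk] = cur.lastrowid  # the database-assigned key
        if generated:
            # Read back the other database-generated values.
            cur.execute("SELECT %s FROM %s WHERE %s = ?"
                        % (",".join(generated), table, pk), [row[pk]])
            row.update(zip(generated, cur.fetchone()))
    elif changed:
        cols = sorted(changed)
        cur.execute("UPDATE %s SET %s WHERE %s = ?"
                    % (table, ",".join(c + " = ?" for c in cols), pk),
                    [row[c] for c in cols] + [row[pk]])
    changed.clear()  # the row is now in sync with the database


conn = sqlite3.connect(":memory:")
cur = conn.cursor()
cur.execute("CREATE TABLE users (id INTEGER PRIMARY KEY, username TEXT,"
            " created TEXT DEFAULT 'set by database')")

user = {"id": None, "username": "someone", "created": None}
changed = set()
save(cur, "users", user, changed, generated=("created",))
print(user)  # {'id': 1, 'username': 'someone', 'created': 'set by database'}

user["username"] = "renamed"
changed.add("username")
save(cur, "users", user, changed, generated=("created",))
print(cur.execute("SELECT username FROM users WHERE id = ?", [user["id"]]).fetchone())  # ('renamed',)
```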