# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
SQLAlchemy implementation of the storage model API ("MAPI").
"""

import os
import platform

from sqlalchemy import (
    create_engine,
    orm,
)
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm.exc import StaleDataError

from aria.utils.collections import OrderedDict
from . import (
    api,
    exceptions,
    collection_instrumentation
)

# Maps the short predicate names accepted in dict-valued filters to the name of the
# comparison method invoked on the column object (see ``_add_value_filter``).
_predicates = {'ge': '__ge__',
               'gt': '__gt__',
               'lt': '__lt__',
               'le': '__le__',
               'eq': '__eq__',
               'ne': '__ne__'}


class SQLAlchemyModelAPI(api.ModelAPI):
    """
    SQLAlchemy implementation of the storage model API ("MAPI").
    """

    def __init__(self, engine, session, **kwargs):
        """
        :param engine: SQLAlchemy engine used for table creation/dropping
        :param session: SQLAlchemy session used for queries and commits
        """
        super(SQLAlchemyModelAPI, self).__init__(**kwargs)
        self._engine = engine
        self._session = session

    def get(self, entry_id, include=None, **kwargs):
        """
        Returns a single result based on the model class and element ID.

        :param entry_id: value of the ``id`` column to look up
        :param include: optional list of column names to include in the query
        :raises exceptions.NotFoundError: if no row has the given ID
        """
        query = self._get_query(include, {'id': entry_id})
        result = query.first()

        # ``first()`` returns ``None`` when nothing matched; test for that explicitly rather
        # than relying on the truthiness of the returned object
        if result is None:
            raise exceptions.NotFoundError(
                'Requested `{0}` with ID `{1}` was not found'
                .format(self.model_cls.__name__, entry_id)
            )
        return self._instrument(result)

    def get_by_name(self, entry_name, include=None, **kwargs):
        """
        Returns the single result whose ``name`` column equals ``entry_name``.

        :param entry_name: value of the ``name`` column to look up
        :param include: optional list of column names to include in the query
        :raises exceptions.NotFoundError: if no row has that name
        :raises exceptions.StorageError: if more than one row has that name
        """
        assert hasattr(self.model_cls, 'name')
        result = self.list(include=include, filters={'name': entry_name})
        if not result:
            raise exceptions.NotFoundError(
                'Requested {0} with name `{1}` was not found'
                .format(self.model_cls.__name__, entry_name)
            )
        elif len(result) > 1:
            raise exceptions.StorageError(
                'Requested {0} with name `{1}` returned more than 1 value'
                .format(self.model_cls.__name__, entry_name)
            )
        else:
            return result[0]

    def list(self, include=None, filters=None, pagination=None, sort=None, **kwargs):
        """
        Returns a :class:`ListResult` of (optionally filtered, sorted and paginated) results.

        The returned object's ``metadata`` dict holds ``total``, ``size`` and ``offset``
        (see ``_paginate``).
        """
        query = self._get_query(include, filters, sort)

        results, total, size, offset = self._paginate(query, pagination)

        return ListResult(
            dict(total=total, size=size, offset=offset),
            [self._instrument(result) for result in results]
        )

    def iter(self, include=None, filters=None, sort=None, **kwargs):
        """
        Yields ``model_class`` results one at a time (possibly none).
        """
        for result in self._get_query(include, filters, sort):
            yield self._instrument(result)

    def put(self, entry, **kwargs):
        """
        Adds ``entry`` to the database session and commits.

        :param entry: model instance to store
        :return: the same ``entry`` instance
        :raises exceptions.StorageError: if the commit fails (see ``_safe_commit``)
        """
        self._session.add(entry)
        self._safe_commit()
        return entry

    def delete(self, entry, **kwargs):
        """
        Deletes ``entry`` from the database and commits.

        Relationships are loaded first (see ``_load_relationships``) so they are still
        accessible on the returned instance.

        :param entry: model instance to delete
        :return: the deleted ``entry`` instance
        """
        self._load_relationships(entry)
        self._session.delete(entry)
        self._safe_commit()
        return entry

    def update(self, entry, **kwargs):
        """
        Adds ``entry`` to the database session and attempts to commit.

        :return: updated instance
        """
        return self.put(entry)

    def refresh(self, entry):
        """
        Reloads the instance with fresh information from the database.

        :param entry: instance to be re-loaded from the database
        :return: refreshed instance
        """
        self._session.refresh(entry)
        self._load_relationships(entry)
        return entry

    def _destroy_connection(self):
        pass

    def _establish_connection(self):
        pass

    def create(self, checkfirst=True, create_all=True, **kwargs):
        """
        Creates the model's table (and, by default, every table in its metadata).

        :param checkfirst: skip tables that already exist
        :param create_all: also create every other table registered in the model metadata
        """
        self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)

        if create_all:
            # In order to create any models created dynamically (e.g. many-to-many helper
            # tables are created at runtime).
            self.model_cls.metadata.create_all(bind=self._engine, checkfirst=checkfirst)

    def drop(self):
        """
        Drops the table.
        """
        self.model_cls.__table__.drop(self._engine)

    def _safe_commit(self):
        """
        Commits the session; on failure rolls back and re-raises as a storage error.

        :raises exceptions.StorageError: prefixed "Version conflict" for ``StaleDataError``,
         or "SQL Storage error" for any other ``SQLAlchemyError`` or ``ValueError``
        """
        try:
            self._session.commit()
        except StaleDataError as e:
            self._session.rollback()
            raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
        except (SQLAlchemyError, ValueError) as e:
            self._session.rollback()
            raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))

    def _get_base_query(self, include, joins):
        """
        Create the initial query from the model class and included columns.

        :param include: (possibly empty) list of columns to include in the query
        :param joins: attributes to join on, in order
        :return: SQLAlchemy AppenderQuery object
        """
        # If only some columns are included, query through the session object
        if include:
            # Make sure that attributes come before association proxies. ``sorted`` is
            # stable and leaves the caller's list untouched.
            include = sorted(include, key=lambda x: x.is_clause_element)
            query = self._session.query(*include)
        else:
            # If all columns should be returned, query directly from the model
            query = self._session.query(self.model_cls)

        query = query.join(*joins)
        return query

    @staticmethod
    def _get_joins(model_class, columns):
        """
        Gets a list of all the attributes on which we need to join.

        :param model_class: model class the column names belong to
        :param columns: set of all attribute names involved in the query
        :return: list of join attributes, in first-seen order
        """

        # Using an ordered dict instead of a set because order is important
        joins = OrderedDict()
        for column_name in columns:
            column = getattr(model_class, column_name)
            while not column.is_attribute:
                join_attr = column.local_attr
                # This is a hack, to deal with the fact that SQLA doesn't
                # fully support doing something like: `if join_attr in joins`,
                # because some SQLA elements have their own comparators
                join_attr_name = str(join_attr)
                if join_attr_name not in joins:
                    joins[join_attr_name] = join_attr
                column = column.remote_attr

        return list(joins.values())

    @staticmethod
    def _sort_query(query, sort=None):
        """
        Adds sorting clauses to the query.

        :param query: base SQL query
        :param sort: optional dictionary where keys are column names to sort by, and values
         are the order (asc/desc)
        :return: SQLAlchemy AppenderQuery object
        """
        if sort:
            for column, order in sort.items():
                if order == 'desc':
                    column = column.desc()
                query = query.order_by(column)
        return query

    def _filter_query(self, query, filters):
        """
        Adds filter clauses to the query.

        :param query: base SQL query
        :param filters: optional dictionary where keys are column names to filter by, and
         values are values applicable for those columns (or lists of such values)
        :return: SQLAlchemy AppenderQuery object
        """
        return self._add_value_filter(query, filters)

    @staticmethod
    def _add_value_filter(query, filters):
        """
        Applies one filter clause per column: dict values use their (already converted)
        comparison-method names, lists/tuples become ``IN``, anything else is ``==``.
        """
        for column, value in filters.items():
            if isinstance(value, dict):
                for predicate, operand in value.items():
                    query = query.filter(getattr(column, predicate)(operand))
            elif isinstance(value, (list, tuple)):
                query = query.filter(column.in_(value))
            else:
                query = query.filter(column == value)

        return query

    def _get_query(self,
                   include=None,
                   filters=None,
                   sort=None):
        """
        Gets a SQL query object based on the params passed.

        :param include: optional list of columns to include in the query
        :param filters: optional dictionary where keys are column names to filter by, and
         values are values applicable for those columns (or lists of such values)
        :param sort: optional dictionary where keys are column names to sort by, and values
         are the order (asc/desc)
        :return: sorted and filtered query with only the relevant columns
        """
        include, filters, sort, joins = self._get_joins_and_converted_columns(
            include, filters, sort
        )
        filters = self._convert_operands(filters)

        query = self._get_base_query(include, joins)
        query = self._filter_query(query, filters)
        query = self._sort_query(query, sort)
        return query

    @staticmethod
    def _convert_operands(filters):
        """
        Translates short predicate names (e.g. ``'ge'``) in dict-valued filters into the
        names of the matching comparison methods (e.g. ``'__ge__'``).

        A new dict is returned; neither ``filters`` nor the per-column condition dicts it
        holds are modified, so the same filters may safely be reused across queries.

        :raises exceptions.StorageError: if a predicate is not a key of ``_predicates``
        """
        converted = {}
        for column, conditions in filters.items():
            if isinstance(conditions, dict):
                converted_conditions = {}
                for predicate, operand in conditions.items():
                    if predicate not in _predicates:
                        raise exceptions.StorageError(
                            "{0} is not a valid predicate for filtering. Valid predicates are {1}"
                            .format(predicate, ', '.join(_predicates.keys())))
                    converted_conditions[_predicates[predicate]] = operand
                conditions = converted_conditions
            converted[column] = conditions
        return converted

    def _get_joins_and_converted_columns(self,
                                         include,
                                         filters,
                                         sort):
        """
        Gets a list of tables on which we need to join and the converted ``include``,
        ``filters`` and ``sort`` arguments (converted to actual SQLAlchemy column/label
        objects instead of column names).
        """
        include = include or []
        filters = filters or dict()
        sort = sort or OrderedDict()

        all_columns = set(include) | set(filters.keys()) | set(sort.keys())
        joins = self._get_joins(self.model_cls, all_columns)

        include, filters, sort = self._get_columns_from_field_names(
            include, filters, sort
        )
        return include, filters, sort, joins

    def _get_columns_from_field_names(self,
                                      include,
                                      filters,
                                      sort):
        """
        Goes over the optional parameters (include, filters, sort), and replaces column
        names with actual SQLAlchemy column objects.
        """
        include = [self._get_column(c) for c in include]
        filters = dict((self._get_column(c), filters[c]) for c in filters)
        sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)

        return include, filters, sort

    def _get_column(self, column_name):
        """
        Returns the column on which an action (filtering, sorting, etc.) would need to be
        performed. Can be either an attribute of the class, or an association proxy linked
        to a relationship in the class.
        """
        column = getattr(self.model_cls, column_name)
        if column.is_attribute:
            return column
        else:
            # We need to get to the underlying attribute, so we move on to the
            # next remote_attr until we reach one
            while not column.remote_attr.is_attribute:
                column = column.remote_attr
            # Put a label on the remote attribute with the name of the column
            return column.remote_attr.label(column_name)

    @staticmethod
    def _paginate(query, pagination):
        """
        Paginates the query by size and offset.

        :param query: current SQLAlchemy query object
        :param pagination: optional dict with size and offset keys
        :return: tuple with four elements:
         * results: ``size`` items starting from ``offset`` (all remaining items when
           ``size`` is 0)
         * the total count of items
         * ``size`` [default: 0]
         * ``offset`` [default: 0]
        """
        if pagination:
            size = pagination.get('size', 0)
            offset = pagination.get('offset', 0)
            total = query.order_by(None).count()  # Fastest way to count
            # A size of 0 means "no limit", consistent with the un-paginated branch below
            # which reports size 0 while returning every result; ``limit(0)`` would
            # otherwise return nothing.
            if size:
                query = query.limit(size)
            if offset:
                query = query.offset(offset)
            results = query.all()
            return results, total, size, offset
        else:
            results = query.all()
            return results, len(results), 0, 0

    @staticmethod
    def _load_relationships(instance):
        """
        Helper method used to overcome a problem where the relationships that rely on joins
        aren't being loaded automatically.
        """
        for rel in instance.__mapper__.relationships:
            # Attribute access is enough to trigger the load
            getattr(instance, rel.key)

    def _instrument(self, model):
        """
        Wraps ``model`` with collection instrumentation when instrumentation is configured;
        otherwise returns it unchanged.
        """
        if self._instrumentation:
            return collection_instrumentation.instrument(self._instrumentation, model, self)
        else:
            return model


def init_storage(base_dir, filename='db.sqlite'):
    """
    Built-in ModelStorage initiator.

    Creates a SQLAlchemy engine and a session to be passed to the MAPI.

    ``initiator_kwargs`` must be passed to the ModelStorage and must hold the ``base_dir``
    for the location of the database file, and optionally a filename. This creates an
    SQLite database.

    :param base_dir: directory of the database
    :param filename: database file name
    :return: dict with ``engine`` and ``session`` (a ``scoped_session``) keys
    """
    uri = 'sqlite:///{platform_char}{path}'.format(
        # Handles the windows behavior where there is no root, but drives.
        # Thus behaving as relative path.
        platform_char='' if 'Windows' in platform.system() else '/',

        path=os.path.join(base_dir, filename))

    engine = create_engine(uri, connect_args=dict(timeout=15))

    session_factory = orm.sessionmaker(bind=engine)
    session = orm.scoped_session(session_factory=session_factory)

    return dict(engine=engine, session=session)


class ListResult(list):
    """
    Contains results about the requested items.
    """
    def __init__(self, metadata, *args, **kwargs):
        """
        :param metadata: dict with ``total``, ``size`` and ``offset`` keys
        """
        super(ListResult, self).__init__(*args, **kwargs)
        self.metadata = metadata
        self.items = self