1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
17 SQLAlchemy implementation of the storage model API ("MAPI").
23 from sqlalchemy import (
27 from sqlalchemy.exc import SQLAlchemyError
28 from sqlalchemy.orm.exc import StaleDataError
30 from aria.utils.collections import OrderedDict
34 collection_instrumentation
37 _predicates = {'ge': '__ge__',
45 class SQLAlchemyModelAPI(api.ModelAPI):
47 SQLAlchemy implementation of the storage model API ("MAPI").
54 super(SQLAlchemyModelAPI, self).__init__(**kwargs)
56 self._session = session
58 def get(self, entry_id, include=None, **kwargs):
60 Returns a single result based on the model class and element ID
62 query = self._get_query(include, {'id': entry_id})
63 result = query.first()
66 raise exceptions.NotFoundError(
67 'Requested `{0}` with ID `{1}` was not found'
68 .format(self.model_cls.__name__, entry_id)
70 return self._instrument(result)
72 def get_by_name(self, entry_name, include=None, **kwargs):
73 assert hasattr(self.model_cls, 'name')
74 result = self.list(include=include, filters={'name': entry_name})
76 raise exceptions.NotFoundError(
77 'Requested {0} with name `{1}` was not found'
78 .format(self.model_cls.__name__, entry_name)
81 raise exceptions.StorageError(
82 'Requested {0} with name `{1}` returned more than 1 value'
83 .format(self.model_cls.__name__, entry_name)
94 query = self._get_query(include, filters, sort)
96 results, total, size, offset = self._paginate(query, pagination)
99 dict(total=total, size=size, offset=offset),
100 [self._instrument(result) for result in results]
109 Returns a (possibly empty) list of ``model_class`` results.
111 for result in self._get_query(include, filters, sort):
112 yield self._instrument(result)
114 def put(self, entry, **kwargs):
116 Creates a ``model_class`` instance from a serializable ``model`` object.
118 :param entry: dict with relevant kwargs, or an instance of a class that has a ``to_dict``
119 method, and whose attributes match the columns of ``model_class`` (might also be just an
120 instance of ``model_class``)
121 :return: an instance of ``model_class``
123 self._session.add(entry)
127 def delete(self, entry, **kwargs):
129 Deletes a single result based on the model class and element ID.
131 self._load_relationships(entry)
132 self._session.delete(entry)
136 def update(self, entry, **kwargs):
138 Adds ``entry`` to the database session, and attempts to commit.
140 :return: updated instance
142 return self.put(entry)
144 def refresh(self, entry):
146 Reloads the instance with fresh information from the database.
148 :param entry: instance to be re-loaded from the database
149 :return: refreshed instance
151 self._session.refresh(entry)
152 self._load_relationships(entry)
155 def _destroy_connection(self):
158 def _establish_connection(self):
161 def create(self, checkfirst=True, create_all=True, **kwargs):
162 self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)
165 # In order to create any models created dynamically (e.g. many-to-many helper tables are
166 # created at runtime).
167 self.model_cls.metadata.create_all(bind=self._engine, checkfirst=checkfirst)
173 self.model_cls.__table__.drop(self._engine)
175 def _safe_commit(self):
177 Tries to commit changes in the session. If an SQLAlchemy error (or a ``ValueError``) is
178 raised, rolls back the session and re-raises it as a ``StorageError``.
181 self._session.commit()
182 except StaleDataError as e:
183 self._session.rollback()
184 raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
185 except (SQLAlchemyError, ValueError) as e:
186 self._session.rollback()
187 raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
189 def _get_base_query(self, include, joins):
191 Create the initial query from the model class and included columns.
193 :param include: (possibly empty) list of columns to include in the query
194 :return: SQLAlchemy AppenderQuery object
196 # If only some columns are included, query through the session object
198 # Make sure that attributes come before association proxies
199 include.sort(key=lambda x: x.is_clause_element)
200 query = self._session.query(*include)
202 # If all columns should be returned, query directly from the model
203 query = self._session.query(self.model_cls)
205 query = query.join(*joins)
209 def _get_joins(model_class, columns):
211 Gets a list of all the tables on which we need to join.
213 :param columns: set of all attributes involved in the query
216 # Using an ordered dict instead of a set because order is important
217 joins = OrderedDict()
218 for column_name in columns:
219 column = getattr(model_class, column_name)
220 while not column.is_attribute:
221 join_attr = column.local_attr
222 # This is a hack, to deal with the fact that SQLA doesn't
223 # fully support doing something like: `if join_attr in joins`,
224 # because some SQLA elements have their own comparators
225 join_attr_name = str(join_attr)
226 if join_attr_name not in joins:
227 joins[join_attr_name] = join_attr
228 column = column.remote_attr
230 return joins.values()
233 def _sort_query(query, sort=None):
235 Adds sorting clauses to the query.
237 :param query: base SQL query
238 :param sort: optional dictionary where keys are column names to sort by, and values are
240 :return: SQLAlchemy AppenderQuery object
243 for column, order in sort.items():
245 column = column.desc()
246 query = query.order_by(column)
249 def _filter_query(self, query, filters):
251 Adds filter clauses to the query.
253 :param query: base SQL query
254 :param filters: optional dictionary where keys are column names to filter by, and values
255 are values applicable for those columns (or lists of such values)
256 :return: SQLAlchemy AppenderQuery object
258 return self._add_value_filter(query, filters)
261 def _add_value_filter(query, filters):
262 for column, value in filters.items():
263 if isinstance(value, dict):
264 for predicate, operand in value.items():
265 query = query.filter(getattr(column, predicate)(operand))
266 elif isinstance(value, (list, tuple)):
267 query = query.filter(column.in_(value))
269 query = query.filter(column == value)
278 Gets a SQL query object based on the params passed.
280 :param model_class: SQL database table class
281 :param include: optional list of columns to include in the query
282 :param filters: optional dictionary where keys are column names to filter by, and values
283 are values applicable for those columns (or lists of such values)
284 :param sort: optional dictionary where keys are column names to sort by, and values are the
286 :return: sorted and filtered query with only the relevant columns
288 include, filters, sort, joins = self._get_joins_and_converted_columns(
289 include, filters, sort
291 filters = self._convert_operands(filters)
293 query = self._get_base_query(include, joins)
294 query = self._filter_query(query, filters)
295 query = self._sort_query(query, sort)
299 def _convert_operands(filters):
300 for column, conditions in filters.items():
301 if isinstance(conditions, dict):
302 for predicate, operand in conditions.items():
303 if predicate not in _predicates:
304 raise exceptions.StorageError(
305 "{0} is not a valid predicate for filtering. Valid predicates are {1}"
306 .format(predicate, ', '.join(_predicates.keys())))
307 del filters[column][predicate]
308 filters[column][_predicates[predicate]] = operand
313 def _get_joins_and_converted_columns(self,
318 Gets a list of tables on which we need to join and the converted ``include``, ``filters``
319 and ``sort`` arguments (converted to actual SQLAlchemy column/label objects instead of
322 include = include or []
323 filters = filters or dict()
324 sort = sort or OrderedDict()
326 all_columns = set(include) | set(filters.keys()) | set(sort.keys())
327 joins = self._get_joins(self.model_cls, all_columns)
329 include, filters, sort = self._get_columns_from_field_names(
330 include, filters, sort
332 return include, filters, sort, joins
334 def _get_columns_from_field_names(self,
339 Goes over the optional parameters (include, filters, sort), and replaces column names with
340 actual SQLAlchemy column objects.
342 include = [self._get_column(c) for c in include]
343 filters = dict((self._get_column(c), filters[c]) for c in filters)
344 sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)
346 return include, filters, sort
348 def _get_column(self, column_name):
350 Returns the column on which an action (filtering, sorting, etc.) would need to be performed.
351 Can be either an attribute of the class, or an association proxy linked to a relationship
354 column = getattr(self.model_cls, column_name)
355 if column.is_attribute:
358 # We need to get to the underlying attribute, so we move on to the
359 # next remote_attr until we reach one
360 while not column.remote_attr.is_attribute:
361 column = column.remote_attr
362 # Put a label on the remote attribute with the name of the column
363 return column.remote_attr.label(column_name)
366 def _paginate(query, pagination):
368 Paginates the query by size and offset.
370 :param query: current SQLAlchemy query object
371 :param pagination: optional dict with size and offset keys
372 :return: tuple with four elements:
373 * results: ``size`` items starting from ``offset``
374 * the total count of items
375 * ``size`` [default: 0]
376 * ``offset`` [default: 0]
379 size = pagination.get('size', 0)
380 offset = pagination.get('offset', 0)
381 total = query.order_by(None).count() # Fastest way to count
382 results = query.limit(size).offset(offset).all()
383 return results, total, size, offset
385 results = query.all()
386 return results, len(results), 0, 0
389 def _load_relationships(instance):
391 Helper method used to overcome a problem where the relationships that rely on joins aren't
392 being loaded automatically.
394 for rel in instance.__mapper__.relationships:
395 getattr(instance, rel.key)
397 def _instrument(self, model):
398 if self._instrumentation:
399 return collection_instrumentation.instrument(self._instrumentation, model, self)
404 def init_storage(base_dir, filename='db.sqlite'):
406 Built-in ModelStorage initiator.
408 Creates a SQLAlchemy engine and a session to be passed to the MAPI.
410 ``initiator_kwargs`` must be passed to the ModelStorage which must hold the ``base_dir`` for the
411 location of the database file, and an optional filename. This would create an SQLite database.
413 :param base_dir: directory of the database
414 :param filename: database file name.
417 uri = 'sqlite:///{platform_char}{path}'.format(
418 # Handles the Windows behavior where there is no root, but drives,
419 # thus behaving as a relative path.
420 platform_char='' if 'Windows' in platform.system() else '/',
422 path=os.path.join(base_dir, filename))
424 engine = create_engine(uri, connect_args=dict(timeout=15))
426 session_factory = orm.sessionmaker(bind=engine)
427 session = orm.scoped_session(session_factory=session_factory)
429 return dict(engine=engine, session=session)
432 class ListResult(list):
434 Contains results about the requested items.
436 def __init__(self, metadata, *args, **qwargs):
437 super(ListResult, self).__init__(*args, **qwargs)
438 self.metadata = metadata