vFW and vDNS support added to azure-plugin
[multicloud/azure.git] / azure / aria / aria-extension-cloudify / src / aria / aria / storage / sql_mapi.py
1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements.  See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License.  You may obtain a copy of the License at
7 #
8 #     http://www.apache.org/licenses/LICENSE-2.0
9 #
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
15
16 """
17 SQLAlchemy implementation of the storage model API ("MAPI").
18 """
19
20 import os
21 import platform
22
23 from sqlalchemy import (
24     create_engine,
25     orm,
26 )
27 from sqlalchemy.exc import SQLAlchemyError
28 from sqlalchemy.orm.exc import StaleDataError
29
30 from aria.utils.collections import OrderedDict
31 from . import (
32     api,
33     exceptions,
34     collection_instrumentation
35 )
36
37 _predicates = {'ge': '__ge__',
38                'gt': '__gt__',
39                'lt': '__lt__',
40                'le': '__le__',
41                'eq': '__eq__',
42                'ne': '__ne__'}
43
44
45 class SQLAlchemyModelAPI(api.ModelAPI):
46     """
47     SQLAlchemy implementation of the storage model API ("MAPI").
48     """
49
50     def __init__(self,
51                  engine,
52                  session,
53                  **kwargs):
54         super(SQLAlchemyModelAPI, self).__init__(**kwargs)
55         self._engine = engine
56         self._session = session
57
58     def get(self, entry_id, include=None, **kwargs):
59         """
60         Returns a single result based on the model class and element ID
61         """
62         query = self._get_query(include, {'id': entry_id})
63         result = query.first()
64
65         if not result:
66             raise exceptions.NotFoundError(
67                 'Requested `{0}` with ID `{1}` was not found'
68                 .format(self.model_cls.__name__, entry_id)
69             )
70         return self._instrument(result)
71
72     def get_by_name(self, entry_name, include=None, **kwargs):
73         assert hasattr(self.model_cls, 'name')
74         result = self.list(include=include, filters={'name': entry_name})
75         if not result:
76             raise exceptions.NotFoundError(
77                 'Requested {0} with name `{1}` was not found'
78                 .format(self.model_cls.__name__, entry_name)
79             )
80         elif len(result) > 1:
81             raise exceptions.StorageError(
82                 'Requested {0} with name `{1}` returned more than 1 value'
83                 .format(self.model_cls.__name__, entry_name)
84             )
85         else:
86             return result[0]
87
88     def list(self,
89              include=None,
90              filters=None,
91              pagination=None,
92              sort=None,
93              **kwargs):
94         query = self._get_query(include, filters, sort)
95
96         results, total, size, offset = self._paginate(query, pagination)
97
98         return ListResult(
99             dict(total=total, size=size, offset=offset),
100             [self._instrument(result) for result in results]
101         )
102
103     def iter(self,
104              include=None,
105              filters=None,
106              sort=None,
107              **kwargs):
108         """
109         Returns a (possibly empty) list of ``model_class`` results.
110         """
111         for result in self._get_query(include, filters, sort):
112             yield self._instrument(result)
113
114     def put(self, entry, **kwargs):
115         """
116         Creatse a ``model_class`` instance from a serializable ``model`` object.
117
118         :param entry: dict with relevant kwargs, or an instance of a class that has a ``to_dict``
119          method, and whose attributes match the columns of ``model_class`` (might also be just an
120          instance of ``model_class``)
121         :return: an instance of ``model_class``
122         """
123         self._session.add(entry)
124         self._safe_commit()
125         return entry
126
127     def delete(self, entry, **kwargs):
128         """
129         Deletes a single result based on the model class and element ID.
130         """
131         self._load_relationships(entry)
132         self._session.delete(entry)
133         self._safe_commit()
134         return entry
135
136     def update(self, entry, **kwargs):
137         """
138         Adds ``instance`` to the database session, and attempts to commit.
139
140         :return: updated instance
141         """
142         return self.put(entry)
143
144     def refresh(self, entry):
145         """
146         Reloads the instance with fresh information from the database.
147
148         :param entry: instance to be re-loaded from the database
149         :return: refreshed instance
150         """
151         self._session.refresh(entry)
152         self._load_relationships(entry)
153         return entry
154
155     def _destroy_connection(self):
156         pass
157
158     def _establish_connection(self):
159         pass
160
161     def create(self, checkfirst=True, create_all=True, **kwargs):
162         self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)
163
164         if create_all:
165             # In order to create any models created dynamically (e.g. many-to-many helper tables are
166             # created at runtime).
167             self.model_cls.metadata.create_all(bind=self._engine, checkfirst=checkfirst)
168
169     def drop(self):
170         """
171         Drops the table.
172         """
173         self.model_cls.__table__.drop(self._engine)
174
175     def _safe_commit(self):
176         """
177         Try to commit changes in the session. Roll back if exception raised SQLAlchemy errors and
178         rolls back if they're caught.
179         """
180         try:
181             self._session.commit()
182         except StaleDataError as e:
183             self._session.rollback()
184             raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
185         except (SQLAlchemyError, ValueError) as e:
186             self._session.rollback()
187             raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
188
189     def _get_base_query(self, include, joins):
190         """
191         Create the initial query from the model class and included columns.
192
193         :param include: (possibly empty) list of columns to include in the query
194         :return: SQLAlchemy AppenderQuery object
195         """
196         # If only some columns are included, query through the session object
197         if include:
198             # Make sure that attributes come before association proxies
199             include.sort(key=lambda x: x.is_clause_element)
200             query = self._session.query(*include)
201         else:
202             # If all columns should be returned, query directly from the model
203             query = self._session.query(self.model_cls)
204
205         query = query.join(*joins)
206         return query
207
208     @staticmethod
209     def _get_joins(model_class, columns):
210         """
211         Gets a list of all the tables on which we need to join.
212
213         :param columns: set of all attributes involved in the query
214         """
215
216         # Using a list instead of a set because order is important
217         joins = OrderedDict()
218         for column_name in columns:
219             column = getattr(model_class, column_name)
220             while not column.is_attribute:
221                 join_attr = column.local_attr
222                 # This is a hack, to deal with the fact that SQLA doesn't
223                 # fully support doing something like: `if join_attr in joins`,
224                 # because some SQLA elements have their own comparators
225                 join_attr_name = str(join_attr)
226                 if join_attr_name not in joins:
227                     joins[join_attr_name] = join_attr
228                 column = column.remote_attr
229
230         return joins.values()
231
232     @staticmethod
233     def _sort_query(query, sort=None):
234         """
235         Adds sorting clauses to the query.
236
237         :param query: base SQL query
238         :param sort: optional dictionary where keys are column names to sort by, and values are
239          the order (asc/desc)
240         :return: SQLAlchemy AppenderQuery object
241         """
242         if sort:
243             for column, order in sort.items():
244                 if order == 'desc':
245                     column = column.desc()
246                 query = query.order_by(column)
247         return query
248
249     def _filter_query(self, query, filters):
250         """
251         Adds filter clauses to the query.
252
253         :param query: base SQL query
254         :param filters: optional dictionary where keys are column names to filter by, and values
255          are values applicable for those columns (or lists of such values)
256         :return: SQLAlchemy AppenderQuery object
257         """
258         return self._add_value_filter(query, filters)
259
260     @staticmethod
261     def _add_value_filter(query, filters):
262         for column, value in filters.items():
263             if isinstance(value, dict):
264                 for predicate, operand in value.items():
265                     query = query.filter(getattr(column, predicate)(operand))
266             elif isinstance(value, (list, tuple)):
267                 query = query.filter(column.in_(value))
268             else:
269                 query = query.filter(column == value)
270
271         return query
272
273     def _get_query(self,
274                    include=None,
275                    filters=None,
276                    sort=None):
277         """
278         Gets a SQL query object based on the params passed.
279
280         :param model_class: SQL database table class
281         :param include: optional list of columns to include in the query
282         :param filters: optional dictionary where keys are column names to filter by, and values
283          are values applicable for those columns (or lists of such values)
284         :param sort: optional dictionary where keys are column names to sort by, and values are the
285          order (asc/desc)
286         :return: sorted and filtered query with only the relevant columns
287         """
288         include, filters, sort, joins = self._get_joins_and_converted_columns(
289             include, filters, sort
290         )
291         filters = self._convert_operands(filters)
292
293         query = self._get_base_query(include, joins)
294         query = self._filter_query(query, filters)
295         query = self._sort_query(query, sort)
296         return query
297
298     @staticmethod
299     def _convert_operands(filters):
300         for column, conditions in filters.items():
301             if isinstance(conditions, dict):
302                 for predicate, operand in conditions.items():
303                     if predicate not in _predicates:
304                         raise exceptions.StorageError(
305                             "{0} is not a valid predicate for filtering. Valid predicates are {1}"
306                             .format(predicate, ', '.join(_predicates.keys())))
307                     del filters[column][predicate]
308                     filters[column][_predicates[predicate]] = operand
309
310
311         return filters
312
313     def _get_joins_and_converted_columns(self,
314                                          include,
315                                          filters,
316                                          sort):
317         """
318         Gets a list of tables on which we need to join and the converted ``include``, ``filters``
319         and ```sort`` arguments (converted to actual SQLAlchemy column/label objects instead of
320         column names).
321         """
322         include = include or []
323         filters = filters or dict()
324         sort = sort or OrderedDict()
325
326         all_columns = set(include) | set(filters.keys()) | set(sort.keys())
327         joins = self._get_joins(self.model_cls, all_columns)
328
329         include, filters, sort = self._get_columns_from_field_names(
330             include, filters, sort
331         )
332         return include, filters, sort, joins
333
334     def _get_columns_from_field_names(self,
335                                       include,
336                                       filters,
337                                       sort):
338         """
339         Gooes over the optional parameters (include, filters, sort), and replace column names with
340         actual SQLAlechmy column objects.
341         """
342         include = [self._get_column(c) for c in include]
343         filters = dict((self._get_column(c), filters[c]) for c in filters)
344         sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)
345
346         return include, filters, sort
347
348     def _get_column(self, column_name):
349         """
350         Returns the column on which an action (filtering, sorting, etc.) would need to be performed.
351         Can be either an attribute of the class, or an association proxy linked to a relationship
352         in the class.
353         """
354         column = getattr(self.model_cls, column_name)
355         if column.is_attribute:
356             return column
357         else:
358             # We need to get to the underlying attribute, so we move on to the
359             # next remote_attr until we reach one
360             while not column.remote_attr.is_attribute:
361                 column = column.remote_attr
362             # Put a label on the remote attribute with the name of the column
363             return column.remote_attr.label(column_name)
364
365     @staticmethod
366     def _paginate(query, pagination):
367         """
368         Paginates the query by size and offset.
369
370         :param query: current SQLAlchemy query object
371         :param pagination: optional dict with size and offset keys
372         :return: tuple with four elements:
373          * results: ``size`` items starting from ``offset``
374          * the total count of items
375          * ``size`` [default: 0]
376          * ``offset`` [default: 0]
377         """
378         if pagination:
379             size = pagination.get('size', 0)
380             offset = pagination.get('offset', 0)
381             total = query.order_by(None).count()  # Fastest way to count
382             results = query.limit(size).offset(offset).all()
383             return results, total, size, offset
384         else:
385             results = query.all()
386             return results, len(results), 0, 0
387
388     @staticmethod
389     def _load_relationships(instance):
390         """
391         Helper method used to overcome a problem where the relationships that rely on joins aren't
392         being loaded automatically.
393         """
394         for rel in instance.__mapper__.relationships:
395             getattr(instance, rel.key)
396
397     def _instrument(self, model):
398         if self._instrumentation:
399             return collection_instrumentation.instrument(self._instrumentation, model, self)
400         else:
401             return model
402
403
404 def init_storage(base_dir, filename='db.sqlite'):
405     """
406     Built-in ModelStorage initiator.
407
408     Creates a SQLAlchemy engine and a session to be passed to the MAPI.
409
410     ``initiator_kwargs`` must be passed to the ModelStorage which must hold the ``base_dir`` for the
411     location of the database file, and an option filename. This would create an SQLite database.
412
413     :param base_dir: directory of the database
414     :param filename: database file name.
415     :return:
416     """
417     uri = 'sqlite:///{platform_char}{path}'.format(
418         # Handles the windows behavior where there is not root, but drivers.
419         # Thus behaving as relative path.
420         platform_char='' if 'Windows' in platform.system() else '/',
421
422         path=os.path.join(base_dir, filename))
423
424     engine = create_engine(uri, connect_args=dict(timeout=15))
425
426     session_factory = orm.sessionmaker(bind=engine)
427     session = orm.scoped_session(session_factory=session_factory)
428
429     return dict(engine=engine, session=session)
430
431
432 class ListResult(list):
433     """
434     Contains results about the requested items.
435     """
436     def __init__(self, metadata, *args, **qwargs):
437         super(ListResult, self).__init__(*args, **qwargs)
438         self.metadata = metadata
439         self.items = self