vFW and vDNS support added to azure-plugin
[multicloud/azure.git] / azure / aria / aria-extension-cloudify / src / aria / aria / storage / sql_mapi.py
diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py b/azure/aria/aria-extension-cloudify/src/aria/aria/storage/sql_mapi.py
new file mode 100644 (file)
index 0000000..975ada7
--- /dev/null
@@ -0,0 +1,439 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+SQLAlchemy implementation of the storage model API ("MAPI").
+"""
+
+import os
+import platform
+
+from sqlalchemy import (
+    create_engine,
+    orm,
+)
+from sqlalchemy.exc import SQLAlchemyError
+from sqlalchemy.orm.exc import StaleDataError
+
+from aria.utils.collections import OrderedDict
+from . import (
+    api,
+    exceptions,
+    collection_instrumentation
+)
+
+_predicates = {'ge': '__ge__',
+               'gt': '__gt__',
+               'lt': '__lt__',
+               'le': '__le__',
+               'eq': '__eq__',
+               'ne': '__ne__'}
+
+
+class SQLAlchemyModelAPI(api.ModelAPI):
+    """
+    SQLAlchemy implementation of the storage model API ("MAPI").
+    """
+
+    def __init__(self,
+                 engine,
+                 session,
+                 **kwargs):
+        super(SQLAlchemyModelAPI, self).__init__(**kwargs)
+        self._engine = engine
+        self._session = session
+
+    def get(self, entry_id, include=None, **kwargs):
+        """
+        Returns a single result based on the model class and element ID
+        """
+        query = self._get_query(include, {'id': entry_id})
+        result = query.first()
+
+        if not result:
+            raise exceptions.NotFoundError(
+                'Requested `{0}` with ID `{1}` was not found'
+                .format(self.model_cls.__name__, entry_id)
+            )
+        return self._instrument(result)
+
+    def get_by_name(self, entry_name, include=None, **kwargs):
+        assert hasattr(self.model_cls, 'name')
+        result = self.list(include=include, filters={'name': entry_name})
+        if not result:
+            raise exceptions.NotFoundError(
+                'Requested {0} with name `{1}` was not found'
+                .format(self.model_cls.__name__, entry_name)
+            )
+        elif len(result) > 1:
+            raise exceptions.StorageError(
+                'Requested {0} with name `{1}` returned more than 1 value'
+                .format(self.model_cls.__name__, entry_name)
+            )
+        else:
+            return result[0]
+
+    def list(self,
+             include=None,
+             filters=None,
+             pagination=None,
+             sort=None,
+             **kwargs):
+        query = self._get_query(include, filters, sort)
+
+        results, total, size, offset = self._paginate(query, pagination)
+
+        return ListResult(
+            dict(total=total, size=size, offset=offset),
+            [self._instrument(result) for result in results]
+        )
+
+    def iter(self,
+             include=None,
+             filters=None,
+             sort=None,
+             **kwargs):
+        """
+        Returns a (possibly empty) list of ``model_class`` results.
+        """
+        for result in self._get_query(include, filters, sort):
+            yield self._instrument(result)
+
+    def put(self, entry, **kwargs):
+        """
+        Creatse a ``model_class`` instance from a serializable ``model`` object.
+
+        :param entry: dict with relevant kwargs, or an instance of a class that has a ``to_dict``
+         method, and whose attributes match the columns of ``model_class`` (might also be just an
+         instance of ``model_class``)
+        :return: an instance of ``model_class``
+        """
+        self._session.add(entry)
+        self._safe_commit()
+        return entry
+
+    def delete(self, entry, **kwargs):
+        """
+        Deletes a single result based on the model class and element ID.
+        """
+        self._load_relationships(entry)
+        self._session.delete(entry)
+        self._safe_commit()
+        return entry
+
+    def update(self, entry, **kwargs):
+        """
+        Adds ``instance`` to the database session, and attempts to commit.
+
+        :return: updated instance
+        """
+        return self.put(entry)
+
+    def refresh(self, entry):
+        """
+        Reloads the instance with fresh information from the database.
+
+        :param entry: instance to be re-loaded from the database
+        :return: refreshed instance
+        """
+        self._session.refresh(entry)
+        self._load_relationships(entry)
+        return entry
+
+    def _destroy_connection(self):
+        pass
+
+    def _establish_connection(self):
+        pass
+
+    def create(self, checkfirst=True, create_all=True, **kwargs):
+        self.model_cls.__table__.create(self._engine, checkfirst=checkfirst)
+
+        if create_all:
+            # In order to create any models created dynamically (e.g. many-to-many helper tables are
+            # created at runtime).
+            self.model_cls.metadata.create_all(bind=self._engine, checkfirst=checkfirst)
+
+    def drop(self):
+        """
+        Drops the table.
+        """
+        self.model_cls.__table__.drop(self._engine)
+
+    def _safe_commit(self):
+        """
+        Try to commit changes in the session. Roll back if exception raised SQLAlchemy errors and
+        rolls back if they're caught.
+        """
+        try:
+            self._session.commit()
+        except StaleDataError as e:
+            self._session.rollback()
+            raise exceptions.StorageError('Version conflict: {0}'.format(str(e)))
+        except (SQLAlchemyError, ValueError) as e:
+            self._session.rollback()
+            raise exceptions.StorageError('SQL Storage error: {0}'.format(str(e)))
+
+    def _get_base_query(self, include, joins):
+        """
+        Create the initial query from the model class and included columns.
+
+        :param include: (possibly empty) list of columns to include in the query
+        :return: SQLAlchemy AppenderQuery object
+        """
+        # If only some columns are included, query through the session object
+        if include:
+            # Make sure that attributes come before association proxies
+            include.sort(key=lambda x: x.is_clause_element)
+            query = self._session.query(*include)
+        else:
+            # If all columns should be returned, query directly from the model
+            query = self._session.query(self.model_cls)
+
+        query = query.join(*joins)
+        return query
+
+    @staticmethod
+    def _get_joins(model_class, columns):
+        """
+        Gets a list of all the tables on which we need to join.
+
+        :param columns: set of all attributes involved in the query
+        """
+
+        # Using a list instead of a set because order is important
+        joins = OrderedDict()
+        for column_name in columns:
+            column = getattr(model_class, column_name)
+            while not column.is_attribute:
+                join_attr = column.local_attr
+                # This is a hack, to deal with the fact that SQLA doesn't
+                # fully support doing something like: `if join_attr in joins`,
+                # because some SQLA elements have their own comparators
+                join_attr_name = str(join_attr)
+                if join_attr_name not in joins:
+                    joins[join_attr_name] = join_attr
+                column = column.remote_attr
+
+        return joins.values()
+
+    @staticmethod
+    def _sort_query(query, sort=None):
+        """
+        Adds sorting clauses to the query.
+
+        :param query: base SQL query
+        :param sort: optional dictionary where keys are column names to sort by, and values are
+         the order (asc/desc)
+        :return: SQLAlchemy AppenderQuery object
+        """
+        if sort:
+            for column, order in sort.items():
+                if order == 'desc':
+                    column = column.desc()
+                query = query.order_by(column)
+        return query
+
+    def _filter_query(self, query, filters):
+        """
+        Adds filter clauses to the query.
+
+        :param query: base SQL query
+        :param filters: optional dictionary where keys are column names to filter by, and values
+         are values applicable for those columns (or lists of such values)
+        :return: SQLAlchemy AppenderQuery object
+        """
+        return self._add_value_filter(query, filters)
+
+    @staticmethod
+    def _add_value_filter(query, filters):
+        for column, value in filters.items():
+            if isinstance(value, dict):
+                for predicate, operand in value.items():
+                    query = query.filter(getattr(column, predicate)(operand))
+            elif isinstance(value, (list, tuple)):
+                query = query.filter(column.in_(value))
+            else:
+                query = query.filter(column == value)
+
+        return query
+
+    def _get_query(self,
+                   include=None,
+                   filters=None,
+                   sort=None):
+        """
+        Gets a SQL query object based on the params passed.
+
+        :param model_class: SQL database table class
+        :param include: optional list of columns to include in the query
+        :param filters: optional dictionary where keys are column names to filter by, and values
+         are values applicable for those columns (or lists of such values)
+        :param sort: optional dictionary where keys are column names to sort by, and values are the
+         order (asc/desc)
+        :return: sorted and filtered query with only the relevant columns
+        """
+        include, filters, sort, joins = self._get_joins_and_converted_columns(
+            include, filters, sort
+        )
+        filters = self._convert_operands(filters)
+
+        query = self._get_base_query(include, joins)
+        query = self._filter_query(query, filters)
+        query = self._sort_query(query, sort)
+        return query
+
+    @staticmethod
+    def _convert_operands(filters):
+        for column, conditions in filters.items():
+            if isinstance(conditions, dict):
+                for predicate, operand in conditions.items():
+                    if predicate not in _predicates:
+                        raise exceptions.StorageError(
+                            "{0} is not a valid predicate for filtering. Valid predicates are {1}"
+                            .format(predicate, ', '.join(_predicates.keys())))
+                    del filters[column][predicate]
+                    filters[column][_predicates[predicate]] = operand
+
+
+        return filters
+
+    def _get_joins_and_converted_columns(self,
+                                         include,
+                                         filters,
+                                         sort):
+        """
+        Gets a list of tables on which we need to join and the converted ``include``, ``filters``
+        and ```sort`` arguments (converted to actual SQLAlchemy column/label objects instead of
+        column names).
+        """
+        include = include or []
+        filters = filters or dict()
+        sort = sort or OrderedDict()
+
+        all_columns = set(include) | set(filters.keys()) | set(sort.keys())
+        joins = self._get_joins(self.model_cls, all_columns)
+
+        include, filters, sort = self._get_columns_from_field_names(
+            include, filters, sort
+        )
+        return include, filters, sort, joins
+
+    def _get_columns_from_field_names(self,
+                                      include,
+                                      filters,
+                                      sort):
+        """
+        Gooes over the optional parameters (include, filters, sort), and replace column names with
+        actual SQLAlechmy column objects.
+        """
+        include = [self._get_column(c) for c in include]
+        filters = dict((self._get_column(c), filters[c]) for c in filters)
+        sort = OrderedDict((self._get_column(c), sort[c]) for c in sort)
+
+        return include, filters, sort
+
+    def _get_column(self, column_name):
+        """
+        Returns the column on which an action (filtering, sorting, etc.) would need to be performed.
+        Can be either an attribute of the class, or an association proxy linked to a relationship
+        in the class.
+        """
+        column = getattr(self.model_cls, column_name)
+        if column.is_attribute:
+            return column
+        else:
+            # We need to get to the underlying attribute, so we move on to the
+            # next remote_attr until we reach one
+            while not column.remote_attr.is_attribute:
+                column = column.remote_attr
+            # Put a label on the remote attribute with the name of the column
+            return column.remote_attr.label(column_name)
+
+    @staticmethod
+    def _paginate(query, pagination):
+        """
+        Paginates the query by size and offset.
+
+        :param query: current SQLAlchemy query object
+        :param pagination: optional dict with size and offset keys
+        :return: tuple with four elements:
+         * results: ``size`` items starting from ``offset``
+         * the total count of items
+         * ``size`` [default: 0]
+         * ``offset`` [default: 0]
+        """
+        if pagination:
+            size = pagination.get('size', 0)
+            offset = pagination.get('offset', 0)
+            total = query.order_by(None).count()  # Fastest way to count
+            results = query.limit(size).offset(offset).all()
+            return results, total, size, offset
+        else:
+            results = query.all()
+            return results, len(results), 0, 0
+
+    @staticmethod
+    def _load_relationships(instance):
+        """
+        Helper method used to overcome a problem where the relationships that rely on joins aren't
+        being loaded automatically.
+        """
+        for rel in instance.__mapper__.relationships:
+            getattr(instance, rel.key)
+
+    def _instrument(self, model):
+        if self._instrumentation:
+            return collection_instrumentation.instrument(self._instrumentation, model, self)
+        else:
+            return model
+
+
+def init_storage(base_dir, filename='db.sqlite'):
+    """
+    Built-in ModelStorage initiator.
+
+    Creates a SQLAlchemy engine and a session to be passed to the MAPI.
+
+    ``initiator_kwargs`` must be passed to the ModelStorage which must hold the ``base_dir`` for the
+    location of the database file, and an option filename. This would create an SQLite database.
+
+    :param base_dir: directory of the database
+    :param filename: database file name.
+    :return:
+    """
+    uri = 'sqlite:///{platform_char}{path}'.format(
+        # Handles the windows behavior where there is not root, but drivers.
+        # Thus behaving as relative path.
+        platform_char='' if 'Windows' in platform.system() else '/',
+
+        path=os.path.join(base_dir, filename))
+
+    engine = create_engine(uri, connect_args=dict(timeout=15))
+
+    session_factory = orm.sessionmaker(bind=engine)
+    session = orm.scoped_session(session_factory=session_factory)
+
+    return dict(engine=engine, session=session)
+
+
+class ListResult(list):
+    """
+    Contains results about the requested items.
+    """
+    def __init__(self, metadata, *args, **qwargs):
+        super(ListResult, self).__init__(*args, **qwargs)
+        self.metadata = metadata
+        self.items = self