vFW and vDNS support added to azure-plugin
multicloud/azure.git: azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/celery.py
diff --git a/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/celery.py b/azure/aria/aria-extension-cloudify/src/aria/aria/orchestrator/workflows/executor/celery.py
new file mode 100644
index 0000000..a2b3513
--- /dev/null
@@ -0,0 +1,99 @@
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Celery task executor.
+"""
+
+import threading
+import Queue
+
+from aria.orchestrator.workflows.executor import BaseExecutor
+
+
+class CeleryExecutor(BaseExecutor):
+    """
+    Celery task executor.
+    """
+
+    def __init__(self, app, *args, **kwargs):
+        super(CeleryExecutor, self).__init__(*args, **kwargs)
+        self._app = app
+        self._started_signaled = False
+        self._started_queue = Queue.Queue(maxsize=1)
+        self._tasks = {}
+        self._results = {}
+        self._receiver = None
+        self._stopped = False
+        self._receiver_thread = threading.Thread(target=self._events_receiver)
+        self._receiver_thread.daemon = True
+        self._receiver_thread.start()
+        self._started_queue.get(timeout=30)  # wait (up to 30 seconds) for the receiver thread to start
+
+    def _execute(self, ctx):
+        # Key both maps by the Celery task id (sent below as task_id) so the
+        # event handlers can look tasks up via event['uuid'].
+        self._tasks[ctx.task.id] = ctx
+        arguments = dict(arg.unwrapped for arg in ctx.task.arguments.itervalues())
+        arguments['ctx'] = ctx.context
+        self._results[ctx.task.id] = self._app.send_task(
+            ctx.operation_mapping,
+            kwargs=arguments,
+            task_id=ctx.task.id,
+            queue=self._get_queue(ctx))
+
+    def close(self):
+        self._stopped = True
+        if self._receiver:
+            self._receiver.should_stop = True  # tell the kombu consumer loop to exit
+        self._receiver_thread.join()
+
+    @staticmethod
+    def _get_queue(task):
+        return None  # TODO: route tasks to dedicated queues; for now everything uses the default queue
+
+    def _events_receiver(self):
+        with self._app.connection() as connection:
+            self._receiver = self._app.events.Receiver(connection, handlers={
+                'task-started': self._celery_task_started,
+                'task-succeeded': self._celery_task_succeeded,
+                'task-failed': self._celery_task_failed,
+            })
+            for _ in self._receiver.itercapture(limit=None, timeout=None, wakeup=True):
+                if not self._started_signaled:
+                    self._started_queue.put(True)  # unblock __init__: the capture loop is running
+                    self._started_signaled = True
+                if self._stopped:
+                    return
+
+    def _celery_task_started(self, event):
+        self._task_started(self._tasks[event['uuid']])
+
+    def _celery_task_succeeded(self, event):
+        task, _ = self._remove_task(event['uuid'])
+        self._task_succeeded(task)
+
+    def _celery_task_failed(self, event):
+        task, async_result = self._remove_task(event['uuid'])
+        try:
+            exception = async_result.result  # the de-serialized exception the failed task raised
+        except BaseException as e:
+            exception = RuntimeError(
+                'Could not de-serialize exception of task {0} --> {1}: {2}'
+                .format(task.name, type(e).__name__, str(e)))
+        self._task_failed(task, exception=exception)
+
+    def _remove_task(self, task_id):
+        return self._tasks.pop(task_id), self._results.pop(task_id)