1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
# celery is an optional dependency: build the application only when it can be
# imported so the thread/process executor tests still run without it.
try:
    import celery as _celery
    app = _celery.Celery()
    app.conf.update(CELERY_RESULT_BACKEND='amqp://')
except ImportError:
    _celery = None
    app = None
import pytest
import retrying

import aria
from aria.modeling import models
from aria.orchestrator import events
from aria.orchestrator.workflows.executor import (
    thread,
    process,
)

import tests
from . import MockContext
def _get_function(func):
    """Build the dotted path ``<this module>.<func name>`` for *func*.

    The resulting string is what the tests put under the ``function`` key of
    a MockContext's task kwargs. The prefix is always this module's
    ``__name__``, not ``func.__module__``.
    """
    module_name = __name__
    function_name = func.__name__
    return '{0}.{1}'.format(module_name, function_name)
def execute_and_assert(executor, storage=None):
    """Run three mock tasks through *executor* and verify the recorded signals.

    One task succeeds, one fails, and one fails with a MockException that
    carries its ``input`` argument. The ``states`` / ``exception`` attributes
    filled in by the signal handlers are then checked repeatedly (every 100 ms,
    for at most 10 s), presumably because executors may finish tasks
    asynchronously.

    :param executor: executor under test; ``execute`` is called once per task.
    :param storage: model storage passed to every MockContext (may be None).
    """
    expected_value = 'value'

    def _make_task(function, **extra_task_kwargs):
        # Every context shares the same storage; only the task kwargs differ.
        task_kwargs = dict(function=_get_function(function))
        task_kwargs.update(extra_task_kwargs)
        return MockContext(storage, task_kwargs=task_kwargs)

    successful_task = _make_task(mock_successful_task)
    failing_task = _make_task(mock_failing_task)
    task_with_inputs = _make_task(
        mock_task_with_input,
        arguments={'input': models.Argument.wrap('input', expected_value)})

    for task in (successful_task, failing_task, task_with_inputs):
        executor.execute(task)

    @retrying.retry(stop_max_delay=10000, wait_fixed=100)
    def assertion():
        assert successful_task.states == ['start', 'success']
        assert failing_task.states == ['start', 'failure']
        assert task_with_inputs.states == ['start', 'failure']
        assert isinstance(failing_task.exception, MockException)
        assert isinstance(task_with_inputs.exception, MockException)
        assert task_with_inputs.exception.message == expected_value

    # The retry decorator only takes effect when the check is actually called.
    assertion()
def test_thread_execute(thread_executor):
    """Run the shared execute/assert scenario on each ``thread_executor`` variant.

    No model storage is supplied, so the mock contexts receive ``None``.
    """
    execute_and_assert(executor=thread_executor)
def test_process_execute(process_executor, storage):
    """Run the shared execute/assert scenario on the process executor.

    Unlike the thread variant, the mock contexts are backed by the ``storage``
    fixture.
    """
    execute_and_assert(executor=process_executor, storage=storage)
def mock_successful_task(**_):
    """Mock task function: accept any keyword arguments and complete normally."""
    return None
def mock_failing_task(**_):
    """Mock task function: always fail by raising ``MockException``.

    Any keyword arguments are accepted and ignored.
    """
    raise MockException
def mock_task_with_input(input, **_):
    """Mock task function: fail with a MockException carrying the ``input`` argument.

    The parameter name must stay ``input`` because the value is supplied as a
    keyword argument of that name.
    """
    error = MockException(input)
    raise error
# Register the mock functions as celery tasks, but only when a celery app
# exists (``app`` may be None when the optional celery package is missing).
if app is not None:
    mock_successful_task = app.task(mock_successful_task)
    mock_failing_task = app.task(mock_failing_task)
    mock_task_with_input = app.task(mock_task_with_input)
class MockException(Exception):
    """Exception raised on purpose by the mock task functions in this module.

    The executor assertions read ``exception.message``. Builtin exceptions
    only provide that attribute on Python 2, where it is deprecated, so it is
    set explicitly here with the same semantics: the sole positional argument,
    otherwise ``''``.
    """

    def __init__(self, *args):
        super(MockException, self).__init__(*args)
        self.message = args[0] if len(args) == 1 else ''
@pytest.fixture
def storage(tmpdir):
    """Yield a SQLAlchemy-backed application model storage rooted in ``tmpdir``.

    After the requesting test finishes, the storage is released through
    ``tests.storage.release_sqlite_storage``.
    """
    _storage = aria.application_model_storage(
        aria.storage.sql_mapi.SQLAlchemyModelAPI,
        initiator_kwargs=dict(base_dir=str(tmpdir)))
    yield _storage
    tests.storage.release_sqlite_storage(_storage)
@pytest.fixture(params=[
    (thread.ThreadExecutor, {'pool_size': 1}),
    (thread.ThreadExecutor, {'pool_size': 2}),
    # The CeleryExecutor variant is disabled: it needs the optional celery
    # package and the amqp:// result backend configured for ``app``.
    # (celery.CeleryExecutor, {'app': app})
])
def thread_executor(request):
    """Yield an in-process executor built from ``request.param``, then close it.

    Each param is an ``(executor_class, constructor_kwargs)`` pair.
    """
    executor_cls, executor_kwargs = request.param
    result = executor_cls(**executor_kwargs)
    yield result
    # NOTE(review): assumes every executor class exposes close(); confirm.
    result.close()
@pytest.fixture
def process_executor():
    """Yield a subprocess-based executor and close it after the test.

    The subprocess needs to load a tests module, so the repository root is
    put on its python path explicitly, as if the project had been installed
    in editable mode.
    """
    result = process.ProcessExecutor(python_path=tests.ROOT_DIR)
    yield result
    # NOTE(review): assumes ProcessExecutor exposes close(); confirm.
    result.close()
@pytest.fixture(autouse=True)
def register_signals():
    """Record task lifecycle signals on the task objects for every test here.

    The handlers do the following:

    * append 'start', 'success' or 'failure' to ``task.states``;
    * on failure, also store the exception on ``task.exception``.

    They stay connected while the test runs and are disconnected on teardown.
    """
    def start_handler(task, *args, **kwargs):
        task.states.append('start')

    def success_handler(task, *args, **kwargs):
        task.states.append('success')

    def failure_handler(task, exception, *args, **kwargs):
        task.states.append('failure')
        task.exception = exception

    events.start_task_signal.connect(start_handler)
    events.on_success_task_signal.connect(success_handler)
    events.on_failure_task_signal.connect(failure_handler)
    # Hand control to the test while the handlers are connected; without this
    # yield they would be disconnected before any task runs.
    yield
    events.start_task_signal.disconnect(start_handler)
    events.on_success_task_signal.disconnect(success_handler)
    events.on_failure_task_signal.disconnect(failure_handler)