1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
24 from fabric.contrib import files
25 from fabric import context_managers
27 from aria.modeling import models
28 from aria.orchestrator import events
29 from aria.orchestrator import workflow
30 from aria.orchestrator.workflows import api
31 from aria.orchestrator.workflows.executor import process
32 from aria.orchestrator.workflows.core import (engine, graph_compiler)
33 from aria.orchestrator.workflows.exceptions import ExecutorException
34 from aria.orchestrator.exceptions import (TaskAbortException, TaskRetryException)
35 from aria.orchestrator.execution_plugin import operations
36 from aria.orchestrator.execution_plugin import constants
37 from aria.orchestrator.execution_plugin.exceptions import (ProcessException, TaskException)
38 from aria.orchestrator.execution_plugin.ssh import operations as ssh_operations
40 from tests import mock, storage, resources
41 from tests.orchestrator.workflows.helpers import events_collector
44 _CUSTOM_BASE_DIR = '/tmp/new-aria-ctx'
47 'host_string': 'localhost',
49 # 'password': 'travis',
50 'key_filename': '/home/travis/.ssh/id_rsa'
# Emit paramiko's transport-level DEBUG records through a stream handler so
# that SSH connection failures can be diagnosed from the test output.
_paramiko_transport_logger = logging.getLogger('paramiko.transport')
_paramiko_transport_logger.addHandler(logging.StreamHandler())
_paramiko_transport_logger.setLevel(logging.DEBUG)
59 @pytest.mark.skipif(not os.environ.get('TRAVIS'), reason='actual ssh server required')
60 class TestWithActualSSHServer(object):
def test_run_script_basic(self):
    """Pass a value through the process env and expect it back as a node attribute."""
    wanted = 'some_value'
    script_env = {'test_value': wanted}
    attributes = self._execute(env=script_env)
    stored = attributes['test_value']
    assert stored.value == wanted
67 @pytest.mark.skip(reason='sudo privileges are required')
68 def test_run_script_as_sudo(self):
69 self._execute(use_sudo=True)
71 assert files.exists('/opt/test_dir')
72 fabric.api.sudo('rm -rf /opt/test_dir')
def test_run_script_default_base_dir(self):
    """Without a ``base_dir`` override, ``work_dir`` lies under ``constants.DEFAULT_BASE_DIR``."""
    attributes = self._execute()
    expected_work_dir = '{0}/work'.format(constants.DEFAULT_BASE_DIR)
    assert attributes['work_dir'].value == expected_work_dir
78 @pytest.mark.skip(reason='Re-enable once output from process executor can be captured')
79 @pytest.mark.parametrize('hide_groups', [[], ['everything']])
80 def test_run_script_with_hide(self, hide_groups):
81 self._execute(hide_output=hide_groups)
83 expected_log_message = ('[localhost] run: source {0}/scripts/'
84 .format(constants.DEFAULT_BASE_DIR))
86 assert expected_log_message not in output
88 assert expected_log_message in output
90 def test_run_script_process_config(self):
91 expected_env_value = 'test_value_env'
92 expected_arg1_value = 'test_value_arg1'
93 expected_arg2_value = 'test_value_arg2'
95 expected_base_dir = _CUSTOM_BASE_DIR
96 props = self._execute(
97 env={'test_value_env': expected_env_value},
99 'args': [expected_arg1_value, expected_arg2_value],
101 'base_dir': expected_base_dir
103 assert props['env_value'].value == expected_env_value
104 assert len(props['bash_version'].value) > 0
105 assert props['arg1_value'].value == expected_arg1_value
106 assert props['arg2_value'].value == expected_arg2_value
107 assert props['cwd'].value == expected_cwd
108 assert props['ctx_path'].value == '{0}/ctx'.format(expected_base_dir)
def test_run_script_command_prefix(self):
    """Run with a ``bash -i`` command prefix and expect ``i`` in the ``dollar_dash`` attribute.

    NOTE(review): ``dollar_dash`` presumably holds the shell's ``$-`` flags as
    reported by the test script -- confirm against the script.
    """
    process_config = {'command_prefix': 'bash -i'}
    attributes = self._execute(process=process_config)
    shell_flags = attributes['dollar_dash'].value
    assert 'i' in shell_flags
def test_run_script_reuse_existing_ctx(self):
    """Run the ``<test_name>_1`` and ``<test_name>_2`` operations and check both stored values."""
    first_value = 'test_value_1'
    second_value = 'test_value_2'
    operation_names = [template.format(self.test_name) for template in ('{0}_1', '{0}_2')]
    script_env = {'test_value1': first_value, 'test_value2': second_value}
    attributes = self._execute(test_operations=operation_names, env=script_env)
    assert attributes['test_value1'].value == first_value
    assert attributes['test_value2'].value == second_value
def test_run_script_download_resource_plain(self, tmpdir):
    """Upload a plain-text resource and expect its content back unchanged in ``test_value``."""
    resource_file = tmpdir.join('resource')
    resource_file.write('content')
    self._upload(source=str(resource_file), path='test_resource')
    attributes = self._execute()
    assert attributes['test_value'].value == 'content'
def test_run_script_download_resource_and_render(self, tmpdir):
    """Upload a ``{{ctx.service.name}}`` template and expect the service name in ``test_value``."""
    template_file = tmpdir.join('resource')
    template_file.write('{{ctx.service.name}}')
    self._upload(source=str(template_file), path='test_resource')
    rendered = self._execute()['test_value'].value
    service_name = self._workflow_context.service.name
    assert rendered == service_name
@pytest.mark.parametrize('value', ['string-value', [1, 2, 3], {'key': 'value'}])
def test_run_script_inputs_as_env_variables_no_override(self, value):
    """Check that a custom input round-trips through the script for each parametrized value.

    String inputs are compared as reported; any other value is decoded from
    the reported text with ``json.loads`` before the comparison.
    """
    reported = self._execute(custom_input=value)['test_value'].value
    # ``basestring`` is the Python 2 common base of ``str`` and ``unicode``.
    if not isinstance(value, basestring):
        reported = json.loads(reported)
    assert value == reported
@pytest.mark.parametrize('value', ['string-value', [1, 2, 3], {'key': 'value'}])
def test_run_script_inputs_as_env_variables_process_env_override(self, value):
    """Check that the ``custom_env_var`` process env entry is what the script reports.

    A custom input is supplied as well; the assertion expects the reported
    value to match the process env entry rather than that input. String values
    are compared as reported; others are decoded with ``json.loads`` first.
    """
    attributes = self._execute(env={'custom_env_var': value},
                               custom_input='custom-input-value')
    reported = attributes['test_value'].value
    # ``basestring`` is the Python 2 common base of ``str`` and ``unicode``.
    if not isinstance(value, basestring):
        reported = json.loads(reported)
    assert value == reported
def test_run_script_error_in_script(self):
    """The exception collected from the failed task must be a ``TaskException``."""
    failure = self._execute_and_get_task_exception()
    assert isinstance(failure, TaskException)
def test_run_script_abort_immediate(self):
    """The failed task must report a ``TaskAbortException`` carrying ``abort-message``."""
    failure = self._execute_and_get_task_exception()
    expected_message = 'abort-message'
    assert isinstance(failure, TaskAbortException)
    assert failure.message == expected_message
def test_run_script_retry(self):
    """The failed task must report a ``TaskRetryException`` carrying ``retry-message``."""
    failure = self._execute_and_get_task_exception()
    expected_message = 'retry-message'
    assert isinstance(failure, TaskRetryException)
    assert failure.message == expected_message
def test_run_script_abort_error_ignored_by_script(self):
    """The failed task must still report a ``TaskAbortException`` with ``abort-message``."""
    failure = self._execute_and_get_task_exception()
    expected_message = 'abort-message'
    assert isinstance(failure, TaskAbortException)
    assert failure.message == expected_message
def test_run_commands(self):
    """Execute a ``touch`` command and verify the file exists on the host afterwards."""
    temp_file_path = '/tmp/very_temporary_file'
    removal_command = 'rm {0}'.format(temp_file_path)
    creation_command = 'touch {0}'.format(temp_file_path)

    with self._ssh_env():
        # Remove any leftover file so its presence later is due to this run.
        if files.exists(temp_file_path):
            fabric.api.run(removal_command)

    self._execute(commands=[creation_command])

    with self._ssh_env():
        assert files.exists(temp_file_path)
        fabric.api.run(removal_command)
183 @pytest.fixture(autouse=True)
184 def _setup(self, request, workflow_context, executor, capfd):
185 self._workflow_context = workflow_context
186 self._executor = executor
188 self.test_name = request.node.originalname or request.node.name
189 with self._ssh_env():
190 for directory in [constants.DEFAULT_BASE_DIR, _CUSTOM_BASE_DIR]:
191 if files.exists(directory):
192 fabric.api.run('rm -rf {0}'.format(directory))
194 @contextlib.contextmanager
196 with self._capfd.disabled():
197 with context_managers.settings(fabric.api.hide('everything'),
207 test_operations=None,
209 process = process or {}
211 process.setdefault('env', {}).update(env)
213 test_operations = test_operations or [self.test_name]
215 local_script_path = os.path.join(resources.DIR, 'scripts', 'test_ssh.sh')
216 script_path = os.path.basename(local_script_path)
217 self._upload(local_script_path, script_path)
220 operation = operations.run_commands_with_ssh
222 operation = operations.run_script_with_ssh
224 node = self._workflow_context.model.node.get_by_name(mock.models.DEPENDENCY_NODE_NAME)
226 'script_path': script_path,
227 'fabric_env': _FABRIC_ENV,
229 'use_sudo': use_sudo,
230 'custom_env_var': custom_input,
231 'test_operation': '',
234 arguments['hide_output'] = hide_output
236 arguments['commands'] = commands
237 interface = mock.models.create_interface(
241 operation_kwargs=dict(
242 function='{0}.{1}'.format(
247 node.interfaces[interface.name] = interface
250 def mock_workflow(ctx, graph):
252 for test_operation in test_operations:
253 op_arguments = arguments.copy()
254 op_arguments['test_operation'] = test_operation
255 ops.append(api.task.OperationTask(
257 interface_name='test',
259 arguments=op_arguments))
263 tasks_graph = mock_workflow(ctx=self._workflow_context) # pylint: disable=no-value-for-parameter
264 graph_compiler.GraphCompiler(
265 self._workflow_context, self._executor.__class__).compile(tasks_graph)
266 eng = engine.Engine({self._executor.__class__: self._executor})
267 eng.execute(self._workflow_context)
268 return self._workflow_context.model.node.get_by_name(
269 mock.models.DEPENDENCY_NODE_NAME).attributes
def _execute_and_get_task_exception(self, *args, **kwargs):
    """Run ``_execute`` expecting an ``ExecutorException`` and return the task's exception.

    All arguments are forwarded to ``_execute``. The returned object is the
    ``exception`` keyword argument of the first collected emission of
    ``events.on_failure_task_signal``.
    """
    failure_signal = events.on_failure_task_signal
    with events_collector(failure_signal) as collected:
        with pytest.raises(ExecutorException):
            self._execute(*args, **kwargs)
    first_emission = collected[failure_signal][0]
    return first_emission['kwargs']['exception']
278 def _upload(self, source, path):
279 self._workflow_context.resource.service.upload(
280 entry_id=str(self._workflow_context.service.id),
286 result = process.ProcessExecutor()
def workflow_context(self, tmpdir):
    """Yield a ``mock.context.simple`` context rooted at ``tmpdir``, then release its storage."""
    context = mock.context.simple(str(tmpdir))
    context.states = []
    context.exception = None
    yield context
    # Runs after the consumer is done with the context.
    storage.release_sqlite_storage(context.model)
301 class TestFabricEnvHideGroupsAndRunCommands(object):
303 def test_fabric_env_default_override(self):
304 # first sanity for no override
306 assert self.mock.settings_merged['timeout'] == constants.FABRIC_ENV_DEFAULTS['timeout']
308 invocation_fabric_env = self.default_fabric_env.copy()
310 invocation_fabric_env['timeout'] = timeout
311 self._run(fabric_env=invocation_fabric_env)
312 assert self.mock.settings_merged['timeout'] == timeout
def test_implicit_host_string(self, mocker):
    """With no ``host_string`` in the fabric env, the actor host's address is used."""
    expected_host_address = '1.1.1.1'
    # Patch the host first, then give the patched host the expected address.
    mocker.patch.object(self._Ctx.task.actor, 'host')
    mocker.patch.object(self._Ctx.task.actor.host, 'host_address', expected_host_address)

    env_without_host = dict(self.default_fabric_env)
    env_without_host.pop('host_string')
    self._run(fabric_env=env_without_host)

    merged_host = self.mock.settings_merged['host_string']
    assert merged_host == expected_host_address
def test_explicit_host_string(self):
    """An explicitly supplied ``host_string`` ends up in the merged fabric settings."""
    host_string = 'explicit_host_string'
    fabric_env = dict(self.default_fabric_env, host_string=host_string)
    self._run(fabric_env=fabric_env)
    merged_host = self.mock.settings_merged['host_string']
    assert merged_host == host_string
def test_override_warn_only(self):
    """``warn_only`` is True when not supplied and False when the env sets it to False."""
    # No explicit value in the env: the merged settings report True.
    self._run(fabric_env=dict(self.default_fabric_env))
    assert self.mock.settings_merged['warn_only'] is True

    # An explicit False in the env must be carried into the merged settings.
    overriding_env = dict(self.default_fabric_env, warn_only=False)
    self._run(fabric_env=overriding_env)
    assert self.mock.settings_merged['warn_only'] is False
def test_missing_host_string(self):
    """Running without ``host_string`` in the fabric env must abort the task.

    Expects a ``TaskAbortException`` whose text contains
    ``host_string`` not supplied (with backticks around the key name).
    """
    fabric_env = self.default_fabric_env.copy()
    del fabric_env['host_string']
    # Only the call under test sits inside ``pytest.raises`` so that a failure
    # while preparing the env cannot be mistaken for the expected abort.
    with pytest.raises(TaskAbortException) as exc_ctx:
        self._run(fabric_env=fabric_env)
    # Checked after the context exits, which is when ``exc_ctx.value`` is set.
    assert '`host_string` not supplied' in str(exc_ctx.value)
def test_missing_user(self):
    """Running without ``user`` in the fabric env must abort the task.

    Expects a ``TaskAbortException`` whose text contains
    ``user`` not supplied (with backticks around the key name).
    """
    fabric_env = self.default_fabric_env.copy()
    del fabric_env['user']
    # Only the call under test sits inside ``pytest.raises`` so that a failure
    # while preparing the env cannot be mistaken for the expected abort.
    with pytest.raises(TaskAbortException) as exc_ctx:
        self._run(fabric_env=fabric_env)
    # Checked after the context exits, which is when ``exc_ctx.value`` is set.
    assert '`user` not supplied' in str(exc_ctx.value)
def test_missing_key_or_password(self):
    """Running without ``key_filename`` in the fabric env must abort the task.

    Expects a ``TaskAbortException`` whose text contains
    "Access credentials not supplied".
    """
    fabric_env = self.default_fabric_env.copy()
    del fabric_env['key_filename']
    # Only the call under test sits inside ``pytest.raises`` so that a failure
    # while preparing the env cannot be mistaken for the expected abort.
    with pytest.raises(TaskAbortException) as exc_ctx:
        self._run(fabric_env=fabric_env)
    # Checked after the context exits, which is when ``exc_ctx.value`` is set.
    assert 'Access credentials not supplied' in str(exc_ctx.value)
def test_hide_in_settings_and_non_viable_groups(self):
    """Valid hide groups reach the merged settings; an unknown group aborts the task."""
    valid_groups = ('running', 'stdout')
    self._run(hide_output=valid_groups)
    merged_groups = self.mock.settings_merged['hide_output']
    assert set(merged_groups) == set(valid_groups)

    # 'bla' is the group expected to be rejected.
    with pytest.raises(TaskAbortException) as exc_ctx:
        self._run(hide_output=('running', 'bla'))
    error_text = str(exc_ctx.value)
    assert '`hide_output` must be a subset of' in error_text
368 def test_run_commands(self):
370 commands = ['command1', 'command2']
374 assert all(item in self.mock.settings_merged.items() for
375 item in self.default_fabric_env.items())
376 assert self.mock.settings_merged['warn_only'] is True
377 assert self.mock.settings_merged['use_sudo'] == use_sudo
378 assert self.mock.commands == commands
379 self.mock.settings_merged = {}
380 self.mock.commands = []
def test_failed_command(self):
    """A failing command raises ``ProcessException`` carrying the mocked result's details."""
    # The mocked ``run``/``sudo`` build a failed result for the command 'fail'.
    with pytest.raises(ProcessException) as exc_ctx:
        self._run(commands=['fail'])

    raised = exc_ctx.value
    mocked_result = self.MockCommandResult
    assert raised.stdout == mocked_result.stdout
    assert raised.stderr == mocked_result.stderr
    assert raised.command == mocked_result.command
    assert raised.exit_code == mocked_result.return_code
393 class MockCommandResult(object):
394 stdout = 'mock_stdout'
395 stderr = 'mock_stderr'
396 command = 'mock_command'
399 def __init__(self, failed):
402 class MockFabricApi(object):
406 self.settings_merged = {}
408 @contextlib.contextmanager
409 def settings(self, *args, **kwargs):
410 self.settings_merged.update(kwargs)
413 self.settings_merged.update({'hide_output': groups})
416 def run(self, command):
417 self.commands.append(command)
418 self.settings_merged['use_sudo'] = False
419 return TestFabricEnvHideGroupsAndRunCommands.MockCommandResult(command == 'fail')
421 def sudo(self, command):
422 self.commands.append(command)
423 self.settings_merged['use_sudo'] = True
424 return TestFabricEnvHideGroupsAndRunCommands.MockCommandResult(command == 'fail')
426 def hide(self, *groups):
429 def exists(self, *args, **kwargs):
433 INSTRUMENTATION_FIELDS = ()
437 def abort(message=None):
438 models.Task.abort(message)
445 @contextlib.contextmanager
446 def instrument(self, *args, **kwargs):
451 logger = logging.getLogger()
454 @contextlib.contextmanager
455 def _mock_self_logging(*args, **kwargs):
457 _Ctx.logging_handlers = _mock_self_logging
459 @pytest.fixture(autouse=True)
460 def _setup(self, mocker):
461 self.default_fabric_env = {
462 'host_string': 'test',
464 'key_filename': 'test',
466 self.mock = self.MockFabricApi()
467 mocker.patch('fabric.api', self.mock)
475 operations.run_commands_with_ssh(
479 fabric_env=fabric_env or self.default_fabric_env,
481 hide_output=hide_output)
484 class TestUtilityFunctions(object):
486 def test_paths(self):
488 local_script_path = '/local/script/path.py'
489 paths = ssh_operations._Paths(base_dir=base_dir,
490 local_script_path=local_script_path)
491 assert paths.local_script_path == local_script_path
492 assert paths.remote_ctx_dir == base_dir
493 assert paths.base_script_path == 'path.py'
494 assert paths.remote_ctx_path == '/path/ctx'
495 assert paths.remote_scripts_dir == '/path/scripts'
496 assert paths.remote_work_dir == '/path/work'
497 assert paths.remote_env_script_path.startswith('/path/scripts/env-path.py-')
498 assert paths.remote_script_path.startswith('/path/scripts/path.py-')
500 def test_write_environment_script_file(self):
502 local_script_path = '/local/script/path.py'
503 paths = ssh_operations._Paths(base_dir=base_dir,
504 local_script_path=local_script_path)
506 local_socket_url = 'local_socket_url'
507 remote_socket_url = 'remote_socket_url'
508 env_script_lines = set([l for l in ssh_operations._write_environment_script_file(
509 process={'env': env},
511 local_socket_url=local_socket_url,
512 remote_socket_url=remote_socket_url
513 ).getvalue().split('\n') if l])
514 expected_env_script_lines = set([
515 'export PATH=/path:$PATH',
516 'export PYTHONPATH=/path:$PYTHONPATH',
517 'chmod +x /path/ctx',
518 'chmod +x {0}'.format(paths.remote_script_path),
519 'export CTX_SOCKET_URL={0}'.format(remote_socket_url),
520 'export LOCAL_CTX_SOCKET_URL={0}'.format(local_socket_url),
523 assert env_script_lines == expected_env_script_lines