1 # Licensed to the Apache Software Foundation (ASF) under one or more
2 # contributor license agreements. See the NOTICE file distributed with
3 # this work for additional information regarding copyright ownership.
4 # The ASF licenses this file to You under the Apache License, Version 2.0
5 # (the "License"); you may not use this file except in compliance with
6 # the License. You may obtain a copy of the License at
8 # http://www.apache.org/licenses/LICENSE-2.0
10 # Unless required by applicable law or agreed to in writing, software
11 # distributed under the License is distributed on an "AS IS" BASIS,
12 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 # See the License for the specific language governing permissions and
14 # limitations under the License.
17 from StringIO import StringIO
27 return utils.MockStorage()
30 @pytest.mark.usefixtures("redirect_logger")
31 class TestCliBase(object):
34 @pytest.fixture(scope="class")
35 def redirect_logger():
37 utils.setup_logger(logger_name='aria.cli.main',
38 handlers=[logging.StreamHandler(TestCliBase._logger_output)],
39 logger_format='%(message)s')
41 utils.setup_logger(logger_name='aria.cli.main',
42 handlers=_default_logger_config['handlers'],
43 level=_default_logger_config['level'])
45 _logger_output = StringIO()
def invoke(self, command):
    """Run *command* via the module-level ``runner`` with an emptied log buffer.

    The shared ``_logger_output`` buffer is cleared first so that anything read
    from it afterwards reflects only output written since this call.

    :param command: passed straight through to ``runner.invoke``.
    :return: whatever ``runner.invoke`` returns.
    """
    # Rewind before truncating: io.StringIO.truncate() leaves the stream
    # position unchanged, so without seek(0) subsequent writes would start at
    # the old offset instead of the beginning. (Python 2's StringIO.StringIO
    # happens to move the position on truncate, but this works for both.)
    self._logger_output.seek(0)
    self._logger_output.truncate(0)
    return runner.invoke(command)
def logger_output_string(self):
    """Return the current contents of the shared ``_logger_output`` buffer."""
    captured = self._logger_output
    return captured.getvalue()
def assert_exception_raised(outcome, expected_exception, expected_msg=''):
    """Assert that *outcome* carries an exception of the expected type and text.

    :param outcome: any object exposing an ``exception`` attribute
        (presumably the result returned by a CLI runner invocation — confirm
        against callers).
    :param expected_exception: class (or tuple of classes) the stored
        exception must be an instance of.
    :param expected_msg: substring that must appear in the exception's string
        form; the default empty string always matches.
    :raises AssertionError: if either check fails.
    """
    raised = outcome.exception
    assert isinstance(raised, expected_exception)
    assert expected_msg in str(raised)
# This exists because I wanted to use monkeypatch to replace a function with one that raises an
# exception. Doing that with an in-place lambda isn't straightforward (a lambda body can only be an
# expression, so it cannot contain a `raise` statement), so this small factory is used instead.
64 def raise_exception(exception, msg=''):
66 def inner(*args, **kwargs):
def get_default_logger_config():
    """Snapshot the current handlers and level of the ``aria.cli.main`` logger.

    :return: dict with ``'handlers'`` (a new list holding the logger's current
        handler objects) and ``'level'`` (the logger's ``level`` attribute).
    """
    logger = logging.getLogger('aria.cli.main')
    # Copy the list: Logger.handlers is the live list that addHandler() and
    # removeHandler() mutate in place. Keeping a reference would let any later
    # reconfiguration of this logger silently rewrite the saved "default".
    return {'handlers': list(logger.handlers),
            'level': logger.level}


# Taken once, at import time, before any test reconfigures the logger.
_default_logger_config = get_default_logger_config()