[VVP] Removing unnecessary trademark lines
[vvp/validation-scripts.git] / ice_validator / tests / helpers.py
1 # -*- coding: utf8 -*-
2 # ============LICENSE_START====================================================
3 # org.onap.vvp/validation-scripts
4 # ===================================================================
5 # Copyright © 2019 AT&T Intellectual Property. All rights reserved.
6 # ===================================================================
7 #
8 # Unless otherwise specified, all software contained herein is licensed
9 # under the Apache License, Version 2.0 (the "License");
10 # you may not use this software except in compliance with the License.
11 # You may obtain a copy of the License at
12 #
13 #             http://www.apache.org/licenses/LICENSE-2.0
14 #
15 # Unless required by applicable law or agreed to in writing, software
16 # distributed under the License is distributed on an "AS IS" BASIS,
17 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18 # See the License for the specific language governing permissions and
19 # limitations under the License.
20 #
21 #
22 #
23 # Unless otherwise specified, all documentation contained herein is licensed
24 # under the Creative Commons License, Attribution 4.0 Intl. (the "License");
25 # you may not use this documentation except in compliance with the License.
26 # You may obtain a copy of the License at
27 #
28 #             https://creativecommons.org/licenses/by/4.0/
29 #
30 # Unless required by applicable law or agreed to in writing, documentation
31 # distributed under the License is distributed on an "AS IS" BASIS,
32 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
33 # See the License for the specific language governing permissions and
34 # limitations under the License.
35 #
36 # ============LICENSE_END============================================
37 #
38 #
39
40 """Helpers
41 """
42
43 import os
44 from collections import defaultdict
45
46 from boltons import funcutils
47 from tests import cached_yaml as yaml
48
49 VERSION = "1.1.0"
50
51
52 def check_basename_ending(template_type, basename):
53     """
54     return True/False if the template type is matching
55     the filename
56     """
57     if not template_type:
58         return True
59     elif template_type == "volume":
60         return basename.endswith("_volume")
61     else:
62         return not basename.endswith("_volume")
63
64
65 def get_parsed_yml_for_yaml_files(yaml_files, sections=None):
66     """
67     get the parsed yaml for a list of yaml files
68     """
69     sections = [] if sections is None else sections
70     parsed_yml_list = []
71     for yaml_file in yaml_files:
72         try:
73             with open(yaml_file) as fh:
74                 yml = yaml.load(fh)
75         except yaml.YAMLError as e:
76             # pylint: disable=superfluous-parens
77             print("Error in %s: %s" % (yaml_file, e))
78             continue
79         if yml:
80             if sections:
81                 for k in yml.keys():
82                     if k not in sections:
83                         del yml[k]
84             parsed_yml_list.append(yml)
85     return parsed_yml_list
86
87
88 def validates(*requirement_ids):
89     """Decorator that tags the test function with one or more requirement IDs.
90
91     Example:
92         >>> @validates('R-12345', 'R-12346')
93         ... def test_something():
94         ...     pass
95         >>> assert test_something.requirement_ids == ['R-12345', 'R-12346']
96     """
97     # pylint: disable=missing-docstring
98     def decorator(func):
99         # NOTE: We use a utility here to ensure that function signatures are
100         # maintained because pytest inspects function signatures to inject
101         # fixtures.  I experimented with a few options, but this is the only
102         # library that worked. Other libraries dynamically generated a
103         # function at run-time, and then lost the requirement_ids attribute
104         @funcutils.wraps(func)
105         def wrapper(*args, **kw):
106             return func(*args, **kw)
107
108         wrapper.requirement_ids = requirement_ids
109         return wrapper
110
111     decorator.requirement_ids = requirement_ids
112     return decorator
113
114
115 def categories(*categories):
116     def decorator(func):
117         @funcutils.wraps(func)
118         def wrapper(*args, **kw):
119             return func(*args, **kw)
120
121         wrapper.categories = categories
122         return wrapper
123
124     decorator.categories = categories
125     return decorator
126
127
128 def get_environment_pair(heat_template):
129     """Returns a yaml/env pair given a yaml file"""
130     base_dir, filename = os.path.split(heat_template)
131     basename = os.path.splitext(filename)[0]
132     env_template = os.path.join(base_dir, "{}.env".format(basename))
133     if os.path.exists(env_template):
134         with open(heat_template, "r") as fh:
135             yyml = yaml.load(fh)
136         with open(env_template, "r") as fh:
137             eyml = yaml.load(fh)
138
139         environment_pair = {"name": basename, "yyml": yyml, "eyml": eyml}
140         return environment_pair
141
142     return None
143
144
145 def find_environment_file(yaml_files):
146     """
147     Pass file and recursively step backwards until environment file is found
148
149     :param yaml_files: list or string, start at size 1 and grows recursively
150     :return: corresponding environment file for a file, or None
151     """
152     # sanitize
153     if isinstance(yaml_files, str):
154         yaml_files = [yaml_files]
155
156     yaml_file = yaml_files[-1]
157     filepath, filename = os.path.split(yaml_file)
158
159     environment_pair = get_environment_pair(yaml_file)
160     if environment_pair:
161         return environment_pair
162
163     for file in os.listdir(filepath):
164         fq_name = "{}/{}".format(filepath, file)
165         if fq_name.endswith("yaml") or fq_name.endswith("yml"):
166             if fq_name not in yaml_files:
167                 with open(fq_name) as f:
168                     yml = yaml.load(f)
169                 resources = yml.get("resources", {})
170                 for resource_id, resource in resources.items():
171                     resource_type = resource.get("type", "")
172                     if resource_type == "OS::Heat::ResourceGroup":
173                         resource_type = (
174                             resource.get("properties", {})
175                             .get("resource_def", {})
176                             .get("type", "")
177                         )
178                     # found called nested file
179                     if resource_type == filename:
180                         yaml_files.append(fq_name)
181                         environment_pair = find_environment_file(yaml_files)
182
183     return environment_pair
184
185
186 def load_yaml(yaml_file):
187     """
188     Load the YAML file at the given path.  If the file has previously been
189     loaded, then a cached version will be returned.
190
191     :param yaml_file: path to the YAML file
192     :return: data structure loaded from the YAML file
193     """
194     with open(yaml_file) as fh:
195         return yaml.load(fh)
196
197
198 def traverse(data, search_key, func, path=None):
199     """
200     Traverse the data structure provided via ``data`` looking for occurences
201     of ``search_key``.  When ``search_key`` is found, the value associated
202     with that key is passed to ``func``
203
204     :param data:        arbitrary data structure of dicts and lists
205     :param search_key:  key field to search for
206     :param func:        Callable object that takes two parameters:
207                         * A list representing the path of keys to search_key
208                         * The value associated with the search_key
209     """
210     path = [] if path is None else path
211     if isinstance(data, dict):
212         for key, value in data.items():
213             curr_path = path + [key]
214             if key == search_key:
215                 func(curr_path, value)
216             traverse(value, search_key, func, curr_path)
217     elif isinstance(data, list):
218         for value in data:
219             curr_path = path + [value]
220             if isinstance(value, dict):
221                 traverse(value, search_key, func, curr_path)
222             elif value == search_key:
223                 func(curr_path, value)
224
225
226 def check_indices(pattern, values, value_type):
227     """
228     Checks that indices associated with the matched prefix start at 0 and
229     increment by 1.  It returns a list of messages for any prefixes that
230     violate the rules.
231
232     :param pattern: Compiled regex that whose first group matches the prefix and
233                     second group matches the index
234     :param values:  sequence of string names that may or may not match the pattern
235     :param name:    Type of value being checked (ex: IP Parameters). This will
236                     be included in the error messages.
237     :return:        List of error messages, empty list if no violations found
238     """
239     if not hasattr(pattern, "match"):
240         raise RuntimeError("Pattern must be a compiled regex")
241
242     prefix_indices = defaultdict(set)
243     for value in values:
244         m = pattern.match(value)
245         if m:
246             prefix_indices[m.group(1)].add(int(m.group(2)))
247
248     invalid_params = []
249     for prefix, indices in prefix_indices.items():
250         indices = sorted(indices)
251         if indices[0] != 0:
252             invalid_params.append(
253                 "{} with prefix {} do not start at 0".format(value_type, prefix)
254             )
255         elif len(indices) - 1 != indices[-1]:
256             invalid_params.append(
257                 (
258                     "Index values of {} with prefix {} do not " + "increment by 1: {}"
259                 ).format(value_type, prefix, indices)
260             )
261     return invalid_params