Add ML based optimization to PCI opt
[optf/osdf.git] / test / apps / pci_optimization / test_ml_model.py
1 # -------------------------------------------------------------------------
2 #   Copyright (C) 2020 Wipro Limited.
3 #
4 #   Licensed under the Apache License, Version 2.0 (the "License");
5 #   you may not use this file except in compliance with the License.
6 #   You may obtain a copy of the License at
7 #
8 #       http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #   Unless required by applicable law or agreed to in writing, software
11 #   distributed under the License is distributed on an "AS IS" BASIS,
12 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 #   See the License for the specific language governing permissions and
14 #   limitations under the License.
15 #
16 # -------------------------------------------------------------------------
17 #
18
19 import copy
20 from mock import patch
21 import unittest
22 from apps.pci.optimizers.solver.ml_model import MlModel
23 from osdf.adapters.dcae.des import DESException
24 import osdf.config.loader as config_loader
25 from osdf.utils.interfaces import json_from_file
26 from osdf.utils.programming_utils import DotDict
27
28
class TestMlModel(unittest.TestCase):
    """Tests for ``MlModel.get_additional_inputs`` with the DES call mocked."""

    def setUp(self):
        """Load the OSDF configuration referenced by ``config_spec``."""
        self.config_spec = {
            "deployment": "config/osdf_config.yaml",
            "core": "config/common_config.yaml"
        }
        self.osdf_config = DotDict(config_loader.all_configs(**self.config_spec))

    def tearDown(self):
        """Nothing to release: every patch is scoped by a ``with`` block."""
        pass

    def test_ml_model(self):
        """Check ``PCI_UNCHANGEABLE_CELLS`` for three DES outcomes.

        Scenarios, run in order against the same ``dzn_data`` dict:

        1. ``extract_data`` yields the canned results from
           ``des_result.json`` -> the list is expected to be ``['Chn0001']``.
        2. ``extract_data`` raises ``DESException`` -> the list stays empty.
        3. ``extract_data`` returns ``[]`` -> the list stays empty.

        Each patch is applied through a context manager so that it is undone
        even when the call under test raises or an assertion fails; a manual
        ``start()``/``stop()`` pair would leave the mock installed for every
        test that runs afterwards in the same process.
        """
        des_target = 'osdf.adapters.dcae.des.extract_data'
        des_result_file = 'test/apps/pci_optimization/des_result.json'
        results = json_from_file(des_result_file)

        dzn_data = {
            'NUM_NODES': 4,
            'NUM_PCIS': 4,
            'NUM_NEIGHBORS': 4,
            'NEIGHBORS': [],
            'NUM_SECOND_LEVEL_NEIGHBORS': 1,
            'SECOND_LEVEL_NEIGHBORS': [],
            'PCI_UNCHANGEABLE_CELLS': [],
            'ORIGINAL_PCIS': []
        }

        network_cell_info = {
            'cell_list': [
                {
                    'cell_id': 'Chn0001',
                    'id': 1,
                    'nbr_list': []
                },
                {
                    'cell_id': 'Chn0002',
                    'id': 2,
                    'nbr_list': []
                }
            ]
        }

        # Scenario 1: canned DES data. With an iterable ``side_effect`` the
        # mock hands out one item per call -- presumably one entry per cell
        # in ``des_result.json``; verify against the fixture.
        with patch(des_target, side_effect=results):
            mlmodel = MlModel()
            mlmodel.get_additional_inputs(dzn_data, network_cell_info)
            self.assertEqual(['Chn0001'], dzn_data['PCI_UNCHANGEABLE_CELLS'])

        # Scenario 2: DES lookup fails. Reset the list first so the assertion
        # cannot be satisfied by leftovers from scenario 1.
        dzn_data['PCI_UNCHANGEABLE_CELLS'] = []
        with patch(des_target, side_effect=DESException('error')):
            mlmodel.get_additional_inputs(dzn_data, network_cell_info)
            self.assertEqual([], dzn_data['PCI_UNCHANGEABLE_CELLS'])

        # Scenario 3: DES lookup succeeds but returns no data.
        with patch(des_target, return_value=[]):
            mlmodel.get_additional_inputs(dzn_data, network_cell_info)
            self.assertEqual([], dzn_data['PCI_UNCHANGEABLE_CELLS'])