Merge "[WIP]Migration to new policy api"
[optf/osdf.git] / runtime / model_api.py
1 # -------------------------------------------------------------------------
2 #   Copyright (c) 2020 AT&T Intellectual Property
3 #
4 #   Licensed under the Apache License, Version 2.0 (the "License");
5 #   you may not use this file except in compliance with the License.
6 #   You may obtain a copy of the License at
7 #
8 #       http://www.apache.org/licenses/LICENSE-2.0
9 #
10 #   Unless required by applicable law or agreed to in writing, software
11 #   distributed under the License is distributed on an "AS IS" BASIS,
12 #   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 #   See the License for the specific language governing permissions and
14 #   limitations under the License.
15 #
16 # -------------------------------------------------------------------------
17 #
18
19 import json
20 import traceback
21
22 import mysql.connector
23 from flask import g, Flask, Response
24
25 from osdf.config.base import osdf_config
26 from osdf.logging.osdf_logging import debug_log, error_log
27 from osdf.operation.exceptions import BusinessException
28
29
30 def init_db():
31     if is_db_enabled():
32         get_db()
33
34
35 def get_db():
36     """Opens a new database connection if there is none yet for the
37         current application context. 
38     """
39     if not hasattr(g, 'pg'):
40         properties = osdf_config['deployment']
41         host, db_port, db = properties["osdfDatabaseHost"], properties["osdfDatabasePort"], \
42                             properties.get("osdfDatabaseSchema")
43         user, password = properties["osdfDatabaseUsername"], properties["osdfDatabasePassword"]
44         g.pg = mysql.connector.connect(host=host, port=db_port, user=user, password=password, database=db)
45     return g.pg
46
47
48 def close_db():
49     """Closes the database again at the end of the request."""
50     if hasattr(g, 'pg'):
51         g.pg.close()
52
53
54 app = Flask(__name__)
55
56
57 def create_model_data(model_api):
58     with app.app_context():
59         try:
60             model_info = model_api['modelInfo']
61             model_id = model_info['modelId']
62             debug_log.debug(
63                 "persisting model_api {}".format(model_id))
64             connection = get_db()
65             cursor = connection.cursor(buffered=True)
66             query = "SELECT model_id FROM optim_model_data WHERE model_id = %s"
67             values = (model_id,)
68             cursor.execute(query, values)
69             if cursor.fetchone() is None:
70                 query = "INSERT INTO optim_model_data (model_id, model_content, description, solver_type) VALUES " \
71                         "(%s, %s, %s, %s)"
72                 values = (model_id, model_info['modelContent'], model_info.get('description'), model_info['solver'])
73                 cursor.execute(query, values)
74                 g.pg.commit()
75
76                 debug_log.debug("A record successfully inserted for request_id: {}".format(model_id))
77                 return retrieve_model_data(model_id)
78                 close_db()
79             else:
80                 query = "UPDATE optim_model_data SET model_content = %s, description = %s, solver_type = %s where " \
81                         "model_id = %s "
82                 values = (model_info['modelContent'], model_info.get('description'), model_info['solver'], model_id)
83                 cursor.execute(query, values)
84                 g.pg.commit()
85
86                 return retrieve_model_data(model_id)
87                 close_db()
88         except Exception as err:
89             error_log.error("error for request_id: {} - {}".format(model_id, traceback.format_exc()))
90             close_db()
91             raise BusinessException(err)
92
93
94 def retrieve_model_data(model_id):
95     status, resp_data = get_model_data(model_id)
96
97     if status == 200:
98         resp = json.dumps(build_model_dict(resp_data))
99         return build_response(resp, status)
100     else:
101         resp = json.dumps({
102             'modelId': model_id,
103             'statusMessage': "Error retrieving the model data for model {} due to {}".format(model_id, resp_data)
104         })
105         return build_response(resp, status)
106
107
108 def build_model_dict(resp_data, content_needed=True):
109     resp = {'modelId': resp_data[0], 'description': resp_data[2] if resp_data[2] else '',
110             'solver': resp_data[3]}
111     if content_needed:
112         resp.update({'modelContent': resp_data[1]})
113     return resp
114
115
116 def build_response(resp, status):
117     response = Response(resp, content_type='application/json; charset=utf-8')
118     response.headers.add('content-length', len(resp))
119     response.status_code = status
120     return response
121
122
123 def delete_model_data(model_id):
124     with app.app_context():
125         try:
126             debug_log.debug("deleting model data given model_id = {}".format(model_id))
127             d = dict();
128             connection = get_db()
129             cursor = connection.cursor(buffered=True)
130             query = "delete from optim_model_data WHERE model_id = %s"
131             values = (model_id,)
132             cursor.execute(query, values)
133             g.pg.commit()
134             close_db()
135             resp = {
136                 "statusMessage": "model data for modelId {} deleted".format(model_id)
137             }
138             return build_response(json.dumps(resp), 200)
139         except Exception as err:
140             error_log.error("error deleting model_id: {} - {}".format(model_id, traceback.format_exc()))
141             close_db()
142             raise BusinessException(err)
143
144
145 def get_model_data(model_id):
146     with app.app_context():
147         try:
148             debug_log.debug("getting model data given model_id = {}".format(model_id))
149             d = dict();
150             connection = get_db()
151             cursor = connection.cursor(buffered=True)
152             query = "SELECT model_id, model_content, description, solver_type  FROM optim_model_data WHERE model_id = %s"
153             values = (model_id,)
154             cursor.execute(query, values)
155             if cursor is None:
156                 return 400, "FAILED"
157             else:
158                 rows = cursor.fetchone()
159                 if rows is not None:
160                     index = 0
161                     for row in rows:
162                         d[index] = row
163                         index = index + 1
164                     return 200, d
165                 else:
166                     close_db()
167                     return 500, "NOT_FOUND"
168         except Exception:
169             error_log.error("error for request_id: {} - {}".format(model_id, traceback.format_exc()))
170             close_db()
171             return 500, "FAILED"
172
173
174 def retrieve_all_models():
175     status, resp_data = get_all_models()
176     model_list = []
177     if status == 200:
178         for r in resp_data:
179             model_list.append(build_model_dict(r, False))
180         resp = json.dumps(model_list)
181         return build_response(resp, status)
182
183     else:
184         resp = json.dumps({
185             'statusMessage': "Error retrieving all the model data due to {}".format(resp_data)
186         })
187         return build_response(resp, status)
188
189
190 def get_all_models():
191     with app.app_context():
192         try:
193             debug_log.debug("getting all model data".format())
194             connection = get_db()
195             cursor = connection.cursor(buffered=True)
196             query = "SELECT model_id, model_content, description, solver_type  FROM optim_model_data"
197     
198             cursor.execute(query)
199             if cursor is None:
200                 return 400, "FAILED"
201             else:
202                 rows = cursor.fetchall()
203                 if rows is not None:
204                     return 200, rows
205                 else:
206                     close_db()
207                     return 500, "NOT_FOUND"
208         except Exception:
209             error_log.error("error for request_id:  {}".format(traceback.format_exc()))
210             close_db()
211             return 500, "FAILED"
212
213
214 def is_db_enabled():
215     return osdf_config['deployment'].get('isDatabaseEnabled', False)