Adding the generic solver code
[optf/osdf.git] / runtime / model_api.py
diff --git a/runtime/model_api.py b/runtime/model_api.py
new file mode 100644 (file)
index 0000000..fd87333
--- /dev/null
@@ -0,0 +1,215 @@
+# -------------------------------------------------------------------------
+#   Copyright (c) 2020 AT&T Intellectual Property
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+#
+# -------------------------------------------------------------------------
+#
+
+import json
+import traceback
+
+import mysql.connector
+from flask import g, Flask, Response
+
+from osdf.config.base import osdf_config
+from osdf.logging.osdf_logging import debug_log, error_log
+from osdf.operation.exceptions import BusinessException
+
+
+def init_db():
+    if is_db_enabled():
+        get_db()
+
+
+def get_db():
+    """Opens a new database connection if there is none yet for the
+        current application context. 
+    """
+    if not hasattr(g, 'pg'):
+        properties = osdf_config['deployment']
+        host, db_port, db = properties["osdfDatabaseHost"], properties["osdfDatabasePort"], \
+                            properties.get("osdfDatabaseSchema")
+        user, password = properties["osdfDatabaseUsername"], properties["osdfDatabasePassword"]
+        g.pg = mysql.connector.connect(host=host, port=db_port, user=user, password=password, database=db)
+    return g.pg
+
+
+def close_db():
+    """Closes the database again at the end of the request."""
+    if hasattr(g, 'pg'):
+        g.pg.close()
+
+
+app = Flask(__name__)
+
+
+def create_model_data(model_api):
+    with app.app_context():
+        try:
+            model_info = model_api['modelInfo']
+            model_id = model_info['modelId']
+            debug_log.debug(
+                "persisting model_api {}".format(model_id))
+            connection = get_db()
+            cursor = connection.cursor(buffered=True)
+            query = "SELECT model_id FROM optim_model_data WHERE model_id = %s"
+            values = (model_id,)
+            cursor.execute(query, values)
+            if cursor.fetchone() is None:
+                query = "INSERT INTO optim_model_data (model_id, model_content, description, solver_type) VALUES " \
+                        "(%s, %s, %s, %s)"
+                values = (model_id, model_info['modelContent'], model_info.get('description'), model_info['solver'])
+                cursor.execute(query, values)
+                g.pg.commit()
+
+                debug_log.debug("A record successfully inserted for request_id: {}".format(model_id))
+                return retrieve_model_data(model_id)
+                close_db()
+            else:
+                query = "UPDATE optim_model_data SET model_content = %s, description = %s, solver_type = %s where " \
+                        "model_id = %s "
+                values = (model_info['modelContent'], model_info.get('description'), model_info['solver'], model_id)
+                cursor.execute(query, values)
+                g.pg.commit()
+
+                return retrieve_model_data(model_id)
+                close_db()
+        except Exception as err:
+            error_log.error("error for request_id: {} - {}".format(model_id, traceback.format_exc()))
+            close_db()
+            raise BusinessException(err)
+
+
+def retrieve_model_data(model_id):
+    status, resp_data = get_model_data(model_id)
+
+    if status == 200:
+        resp = json.dumps(build_model_dict(resp_data))
+        return build_response(resp, status)
+    else:
+        resp = json.dumps({
+            'modelId': model_id,
+            'statusMessage': "Error retrieving the model data for model {} due to {}".format(model_id, resp_data)
+        })
+        return build_response(resp, status)
+
+
+def build_model_dict(resp_data, content_needed=True):
+    resp = {'modelId': resp_data[0], 'description': resp_data[2] if resp_data[2] else '',
+            'solver': resp_data[3]}
+    if content_needed:
+        resp.update({'modelContent': resp_data[1]})
+    return resp
+
+
+def build_response(resp, status):
+    response = Response(resp, content_type='application/json; charset=utf-8')
+    response.headers.add('content-length', len(resp))
+    response.status_code = status
+    return response
+
+
+def delete_model_data(model_id):
+    with app.app_context():
+        try:
+            debug_log.debug("deleting model data given model_id = {}".format(model_id))
+            d = dict();
+            connection = get_db()
+            cursor = connection.cursor(buffered=True)
+            query = "delete from optim_model_data WHERE model_id = %s"
+            values = (model_id,)
+            cursor.execute(query, values)
+            g.pg.commit()
+            close_db()
+            resp = {
+                "statusMessage": "model data for modelId {} deleted".format(model_id)
+            }
+            return build_response(json.dumps(resp), 200)
+        except Exception as err:
+            error_log.error("error deleting model_id: {} - {}".format(model_id, traceback.format_exc()))
+            close_db()
+            raise BusinessException(err)
+
+
+def get_model_data(model_id):
+    with app.app_context():
+        try:
+            debug_log.debug("getting model data given model_id = {}".format(model_id))
+            d = dict();
+            connection = get_db()
+            cursor = connection.cursor(buffered=True)
+            query = "SELECT model_id, model_content, description, solver_type  FROM optim_model_data WHERE model_id = %s"
+            values = (model_id,)
+            cursor.execute(query, values)
+            if cursor is None:
+                return 400, "FAILED"
+            else:
+                rows = cursor.fetchone()
+                if rows is not None:
+                    index = 0
+                    for row in rows:
+                        d[index] = row
+                        index = index + 1
+                    return 200, d
+                else:
+                    close_db()
+                    return 500, "NOT_FOUND"
+        except Exception:
+            error_log.error("error for request_id: {} - {}".format(model_id, traceback.format_exc()))
+            close_db()
+            return 500, "FAILED"
+
+
+def retrieve_all_models():
+    status, resp_data = get_all_models()
+    model_list = []
+    if status == 200:
+        for r in resp_data:
+            model_list.append(build_model_dict(r, False))
+        resp = json.dumps(model_list)
+        return build_response(resp, status)
+
+    else:
+        resp = json.dumps({
+            'statusMessage': "Error retrieving all the model data due to {}".format(resp_data)
+        })
+        return build_response(resp, status)
+
+
+def get_all_models():
+    with app.app_context():
+        try:
+            debug_log.debug("getting all model data".format())
+            connection = get_db()
+            cursor = connection.cursor(buffered=True)
+            query = "SELECT model_id, model_content, description, solver_type  FROM optim_model_data"
+    
+            cursor.execute(query)
+            if cursor is None:
+                return 400, "FAILED"
+            else:
+                rows = cursor.fetchall()
+                if rows is not None:
+                    return 200, rows
+                else:
+                    close_db()
+                    return 500, "NOT_FOUND"
+        except Exception:
+            error_log.error("error for request_id:  {}".format(traceback.format_exc()))
+            close_db()
+            return 500, "FAILED"
+
+
+def is_db_enabled():
+    return osdf_config['deployment'].get('isDatabaseEnabled', False)