Make the jax-rs resource a spring boot RestController
[aai/traversal.git] / aai-traversal / src / main / java / org / onap / aai / rest / DslConsumer.java
index d814b48..2a02ff5 100644 (file)
@@ -22,34 +22,24 @@ package org.onap.aai.rest;
 
 import java.io.FileNotFoundException;
 import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
 import javax.servlet.http.HttpServletRequest;
-import javax.ws.rs.Consumes;
-import javax.ws.rs.DefaultValue;
-import javax.ws.rs.PUT;
-import javax.ws.rs.Path;
-import javax.ws.rs.PathParam;
-import javax.ws.rs.Produces;
-import javax.ws.rs.QueryParam;
-import javax.ws.rs.core.Context;
-import javax.ws.rs.core.HttpHeaders;
-import javax.ws.rs.core.MediaType;
 import javax.ws.rs.core.MultivaluedHashMap;
 import javax.ws.rs.core.MultivaluedMap;
-import javax.ws.rs.core.Response;
-import javax.ws.rs.core.Response.Status;
-import javax.ws.rs.core.UriInfo;
 
+import org.antlr.v4.runtime.tree.ParseTreeListener;
 import org.apache.tinkerpop.gremlin.process.traversal.dsl.graph.GraphTraversalSource;
-import org.janusgraph.core.SchemaViolationException;
-import org.onap.aai.concurrent.AaiCallable;
+import org.onap.aai.edges.EdgeIngestor;
 import org.onap.aai.exceptions.AAIException;
+import org.onap.aai.introspection.LoaderFactory;
 import org.onap.aai.introspection.ModelType;
 import org.onap.aai.rest.db.HttpEntry;
 import org.onap.aai.rest.dsl.DslQueryProcessor;
@@ -57,7 +47,6 @@ import org.onap.aai.rest.enums.QueryVersion;
 import org.onap.aai.rest.search.GenericQueryProcessor;
 import org.onap.aai.rest.search.GremlinServerSingleton;
 import org.onap.aai.rest.search.QueryProcessorType;
-import org.onap.aai.restcore.HttpMethod;
 import org.onap.aai.serialization.db.DBSerializer;
 import org.onap.aai.serialization.engines.TransactionalGraphEngine;
 import org.onap.aai.serialization.queryformats.Format;
@@ -73,6 +62,16 @@ import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.beans.factory.annotation.Value;
+import org.springframework.http.HttpHeaders;
+import org.springframework.http.MediaType;
+import org.springframework.http.ResponseEntity;
+import org.springframework.web.bind.annotation.PathVariable;
+import org.springframework.web.bind.annotation.PutMapping;
+import org.springframework.web.bind.annotation.RequestBody;
+import org.springframework.web.bind.annotation.RequestHeader;
+import org.springframework.web.bind.annotation.RequestMapping;
+import org.springframework.web.bind.annotation.RequestParam;
+import org.springframework.web.bind.annotation.RestController;
 
 import com.google.gson.JsonElement;
 import com.google.gson.JsonObject;
@@ -81,7 +80,8 @@ import com.google.gson.JsonParser;
 import io.micrometer.core.annotation.Timed;
 
 @Timed
-@Path("{version: v[1-9][0-9]*|latest}/dsl")
+@RestController
+@RequestMapping("/{version:v[1-9][0-9]*|latest}/dsl")
 public class DslConsumer extends TraversalConsumer {
 
     private static final Logger LOGGER = LoggerFactory.getLogger(DslConsumer.class);
@@ -89,55 +89,58 @@ public class DslConsumer extends TraversalConsumer {
     private static final QueryVersion DEFAULT_VERSION = QueryVersion.V1;
 
     private final HttpEntry traversalUriHttpEntry;
-    private final DslQueryProcessor dslQueryProcessor;
     private final SchemaVersions schemaVersions;
     private final String basePath;
     private final GremlinServerSingleton gremlinServerSingleton;
     private final XmlFormatTransformer xmlFormatTransformer;
+    // private final Map<QueryVersion, ParseTreeListener> dslListeners;
+    private final EdgeIngestor edgeIngestor;
+    private final LoaderFactory loaderFactory;
 
     private QueryVersion dslApiVersion = DEFAULT_VERSION;
 
     @Autowired
-    public DslConsumer(HttpEntry traversalUriHttpEntry, DslQueryProcessor dslQueryProcessor,
+    public DslConsumer(HttpEntry traversalUriHttpEntry,
             SchemaVersions schemaVersions, GremlinServerSingleton gremlinServerSingleton,
             XmlFormatTransformer xmlFormatTransformer,
+            EdgeIngestor edgeIngestor, LoaderFactory loaderFactory,
             @Value("${schema.uri.base.path}") String basePath) {
         this.traversalUriHttpEntry = traversalUriHttpEntry;
-        this.dslQueryProcessor = dslQueryProcessor;
         this.schemaVersions = schemaVersions;
         this.gremlinServerSingleton = gremlinServerSingleton;
         this.xmlFormatTransformer = xmlFormatTransformer;
         this.basePath = basePath;
+        this.edgeIngestor = edgeIngestor;
+        this.loaderFactory = loaderFactory;
     }
 
-    @PUT
-    @Consumes({MediaType.APPLICATION_JSON})
-    @Produces({MediaType.APPLICATION_JSON, MediaType.APPLICATION_XML})
-    public Response executeQuery(String dslQuery, @PathParam("version") String versionParam,
-            @DefaultValue("graphson") @QueryParam("format") String queryFormat,
-            @DefaultValue("no_op") @QueryParam("subgraph") String subgraph,
-            @DefaultValue("all") @QueryParam("validate") String validate,
-            @DefaultValue("-1") @QueryParam("resultIndex") String resultIndex,
-            @DefaultValue("-1") @QueryParam("resultSize") String resultSize,
-            @Context HttpHeaders headers,
-            @Context HttpServletRequest req,
-            @Context UriInfo info) throws FileNotFoundException, AAIException {
-        Set<String> roles = this.getRoles(req.getUserPrincipal());
-
-        return processExecuteQuery(dslQuery, req, versionParam, queryFormat, subgraph,
-                validate, headers, info, resultIndex, resultSize, roles);
+    @PutMapping(produces = {MediaType.APPLICATION_JSON_VALUE, MediaType.APPLICATION_XML_VALUE})
+    public ResponseEntity<String> executeQuery(@RequestBody String dslQuery,
+                                               @PathVariable("version") String versionParam,
+                                               @RequestParam(defaultValue = "graphson") String format,
+                                               @RequestParam(defaultValue = "no_op") String subgraph,
+                                               @RequestParam(defaultValue = "all") String validate,
+                                               @RequestParam(defaultValue = "-1") String resultIndex,
+                                               @RequestParam(defaultValue = "-1") String resultSize,
+                                               @RequestHeader HttpHeaders headers,
+                                               HttpServletRequest request) throws FileNotFoundException, AAIException {
+        Set<String> roles = this.getRoles(request.getUserPrincipal());
+
+        return processExecuteQuery(dslQuery, request, versionParam, format, subgraph,
+                validate, headers, resultIndex, resultSize, roles);
     }
 
-    public Response processExecuteQuery(String dslQuery, HttpServletRequest request, String versionParam,
-            String queryFormat, String subgraph, String validate, HttpHeaders headers, UriInfo info,
+    public ResponseEntity<String> processExecuteQuery(String dslQuery, HttpServletRequest request, String versionParam,
+            String queryFormat, String subgraph, String validate, HttpHeaders headers,
             String resultIndex, String resultSize, Set<String> roles) throws FileNotFoundException, AAIException {
 
         final SchemaVersion version = new SchemaVersion(versionParam);
-        final String sourceOfTruth = headers.getRequestHeaders().getFirst("X-FromAppId");
-        final String dslOverride = headers.getRequestHeaders().getFirst("X-DslOverride");
+        final String sourceOfTruth = headers.getFirst("X-FromAppId");
+        final String dslOverride = headers.getFirst("X-DslOverride");
+        final MultivaluedMap<String,String> queryParams = toMultivaluedMap(request.getParameterMap());
 
         Optional<String> dslApiVersionHeader =
-            Optional.ofNullable(headers.getRequestHeaders().getFirst("X-DslApiVersion"));
+            Optional.ofNullable(headers.getFirst("X-DslApiVersion"));
         if (dslApiVersionHeader.isPresent()) {
             try {
                 dslApiVersion = QueryVersion.valueOf(dslApiVersionHeader.get());
@@ -146,25 +149,25 @@ public class DslConsumer extends TraversalConsumer {
             }
         }
 
-        String result = executeQuery(dslQuery, request, queryFormat, subgraph, validate, info.getQueryParameters(), resultIndex, resultSize,
+        String result = executeQuery(dslQuery, request, queryFormat, subgraph, validate, queryParams, resultIndex, resultSize,
                 roles, version, sourceOfTruth, dslOverride);
+        MediaType acceptType = headers.getAccept().stream()
+            .filter(Objects::nonNull)
+            .filter(header -> !header.equals(MediaType.ALL))
+            .findAny()
+            .orElse(MediaType.APPLICATION_JSON);
 
-        String acceptType = headers.getHeaderString("Accept");
-        if (acceptType == null) {
-            acceptType = MediaType.APPLICATION_JSON;
-        }
-
-        if (MediaType.APPLICATION_XML_TYPE.isCompatible(MediaType.valueOf(acceptType))) {
+        if (MediaType.APPLICATION_XML.isCompatibleWith(acceptType)) {
             result = xmlFormatTransformer.transform(result);
         }
 
         if (traversalUriHttpEntry.isPaginated()) {
-            return Response.status(Status.OK).type(acceptType)
-                    .header("total-results", traversalUriHttpEntry.getTotalVertices())
-                    .header("total-pages", traversalUriHttpEntry.getTotalPaginationBuckets())
-                    .entity(result).build();
+            return ResponseEntity.ok()
+                .header("total-results", String.valueOf(traversalUriHttpEntry.getTotalVertices()))
+                .header("total-pages", String.valueOf(traversalUriHttpEntry.getTotalPaginationBuckets()))
+                .body(result);
         } else {
-            return Response.status(Status.OK).type(acceptType).entity(result).build();
+            return ResponseEntity.ok(result);
         }
     }
 
@@ -176,7 +179,6 @@ public class DslConsumer extends TraversalConsumer {
             req.getRequestURL().toString().replaceAll("/(v[0-9]+|latest)/.*", "/");
         traversalUriHttpEntry.setHttpEntryProperties(version, serverBase);
         traversalUriHttpEntry.setPaginationParameters(resultIndex, resultSize);
-        final TransactionalGraphEngine dbEngine = traversalUriHttpEntry.getDbEngine();
 
         JsonObject input = JsonParser.parseString(content).getAsJsonObject();
         JsonElement dslElement = input.get("dsl");
@@ -189,6 +191,12 @@ public class DslConsumer extends TraversalConsumer {
                 && !AAIConfig.get(TraversalConstants.DSL_OVERRIDE).equals("false")
                 && dslOverride.equals(AAIConfig.get(TraversalConstants.DSL_OVERRIDE));
 
+        Map<QueryVersion, ParseTreeListener> dslListeners = new HashMap<>();
+        dslListeners.put(QueryVersion.V1,
+            new org.onap.aai.rest.dsl.v1.DslListener(edgeIngestor, schemaVersions, loaderFactory));
+        dslListeners.put(QueryVersion.V2,
+            new org.onap.aai.rest.dsl.v2.DslListener(edgeIngestor, schemaVersions, loaderFactory));
+        DslQueryProcessor dslQueryProcessor = new DslQueryProcessor(dslListeners);
         if (isDslOverride) {
             dslQueryProcessor.setStartNodeValidationFlag(false);
         }
@@ -205,9 +213,10 @@ public class DslConsumer extends TraversalConsumer {
             validateHistoryParams(format, queryParameters);
         }
 
+        final TransactionalGraphEngine dbEngine = traversalUriHttpEntry.getDbEngine();
         GraphTraversalSource traversalSource =
             getTraversalSource(dbEngine, format, queryParameters, roles);
-
+        
         GenericQueryProcessor processor =
             new GenericQueryProcessor.Builder(dbEngine, gremlinServerSingleton)
                 .queryFrom(dsl, "dsl").queryProcessor(dslQueryProcessor).version(dslApiVersion)
@@ -217,12 +226,11 @@ public class DslConsumer extends TraversalConsumer {
         SubGraphStyle subGraphStyle = SubGraphStyle.valueOf(subgraph);
         List<Object> vertTemp = processor.execute(subGraphStyle);
 
-        // Dedup if duplicate objects are returned in each array in the aggregate format
-        // scenario.
-        List<Object> vertTempDedupedObjectList = dedupObjectInAggregateFormatResult(vertTemp);
-
         List<Object> vertices;
         if (isAggregate(format)) {
+            // Dedup if duplicate objects are returned in each array in the aggregate format
+            // scenario.
+            List<Object> vertTempDedupedObjectList = dedupObjectInAggregateFormatResult(vertTemp);
             vertices = traversalUriHttpEntry
                     .getPaginatedVertexListForAggregateFormat(vertTempDedupedObjectList);
         } else {
@@ -252,6 +260,12 @@ public class DslConsumer extends TraversalConsumer {
         return result;
     }
 
+    private List<Object> dedupObjectInAggregateFormatResultStreams(List<Object> vertTemp) {
+        return vertTemp.stream()
+            .filter(o -> o instanceof ArrayList)
+            .map(o -> ((ArrayList<?>) o).stream().distinct().collect(Collectors.toList()))
+            .collect(Collectors.toList());
+    }
     private List<Object> dedupObjectInAggregateFormatResult(List<Object> vertTemp) {
         List<Object> vertTempDedupedObjectList = new ArrayList<Object>();
         Iterator<Object> itr = vertTemp.listIterator();
@@ -264,4 +278,15 @@ public class DslConsumer extends TraversalConsumer {
         }
         return vertTempDedupedObjectList;
     }
+
+    private MultivaluedMap<String, String> toMultivaluedMap(Map<String, String[]> map) {
+        MultivaluedMap<String, String> multivaluedMap = new MultivaluedHashMap<>();
+
+        for (Map.Entry<String, String[]> entry : map.entrySet()) {
+            for (String val : entry.getValue())
+            multivaluedMap.add(entry.getKey(), val);
+        }
+
+        return multivaluedMap;
+    }
 }