reduce CDS java security vulnerabilities
[ccsdk/cds.git] / ms / blueprintsprocessor / functions / netconf-executor / src / main / kotlin / org / onap / ccsdk / cds / blueprintsprocessor / functions / netconf / executor / core / NetconfSessionImpl.kt
index b1121b3..31d90fd 100644 (file)
@@ -21,11 +21,9 @@ import com.google.common.collect.ImmutableSet
 import org.apache.sshd.client.SshClient
 import org.apache.sshd.client.channel.ClientChannel
 import org.apache.sshd.client.session.ClientSession
-import org.apache.sshd.common.FactoryManager
-import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider
+import org.apache.sshd.core.CoreModuleProperties
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.DeviceInfo
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfException
-import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfReceivedEvent
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfRpcService
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfSession
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfSessionListener
@@ -65,8 +63,10 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
 
     override fun connect() {
         try {
-            log.info("$deviceInfo: Connecting to Netconf Device with timeouts C:${deviceInfo.connectTimeout}, " +
-                    "R:${deviceInfo.replyTimeout}, I:${deviceInfo.idleTimeout}")
+            log.info(
+                "$deviceInfo: Connecting to Netconf Device with timeouts C:${deviceInfo.connectTimeout}, " +
+                    "R:${deviceInfo.replyTimeout}, I:${deviceInfo.idleTimeout}"
+            )
             startConnection()
             log.info("$deviceInfo: Connected to Netconf Device")
         } catch (e: NetconfException) {
@@ -76,8 +76,16 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
     }
 
     override fun disconnect() {
-        if (rpcService.closeSession(false).status.equals(
-                RpcStatus.FAILURE, true)) {
+        var retryNum = 3
+        while (rpcService.closeSession(false).status
+            .equals(RpcStatus.FAILURE, true) && retryNum > 0
+        ) {
+            log.error("disconnect: graceful disconnect failed, retrying $retryNum times...")
+            retryNum--
+        }
+        // if we can't close the session, try to force terminate.
+        if (retryNum == 0) {
+            log.error("disconnect: trying to force-terminate the session.")
             rpcService.closeSession(true)
         }
         try {
@@ -98,14 +106,17 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
         checkAndReestablish()
 
         try {
-            return streamHandler.getFutureFromSendMessage(streamHandler.sendMessage(formattedRequest, messageId),
-                replyTimeout.toLong(), TimeUnit.SECONDS)
+            return streamHandler.getFutureFromSendMessage(
+                streamHandler.sendMessage(formattedRequest, messageId),
+                replyTimeout.toLong(), TimeUnit.SECONDS
+            )
         } catch (e: InterruptedException) {
-            Thread.currentThread().interrupt()
             throw NetconfException("$deviceInfo: Interrupted while waiting for reply for request: $formattedRequest", e)
         } catch (e: TimeoutException) {
-            throw NetconfException("$deviceInfo: Timed out while waiting for reply for request $formattedRequest after $replyTimeout sec.",
-                e)
+            throw NetconfException(
+                "$deviceInfo: Timed out while waiting for reply for request $formattedRequest after $replyTimeout sec.",
+                e
+            )
         } catch (e: ExecutionException) {
             log.warn("$deviceInfo: Closing session($sessionId) due to unexpected Error", e)
             try {
@@ -183,10 +194,9 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
         } catch (e: Exception) {
             throw NetconfException("$deviceInfo: Failed to establish SSH session", e)
         }
-
     }
 
-    //Needed to unit test connect method interacting with client.start in startClient() below
+    // Needed to unit test connect method interacting with client.start in startClient() below
     private fun setupNewSSHClient() {
         client = SshClient.setUpDefaultClient()
     }
@@ -194,8 +204,8 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
     private fun startClient() {
         setupNewSSHClient()
 
-        client.properties.putIfAbsent(FactoryManager.IDLE_TIMEOUT, TimeUnit.SECONDS.toMillis(idleTimeout.toLong()))
-        client.properties.putIfAbsent(FactoryManager.NIO2_READ_TIMEOUT, TimeUnit.SECONDS.toMillis(idleTimeout + 15L))
+        client.properties.putIfAbsent(CoreModuleProperties.IDLE_TIMEOUT.name, TimeUnit.SECONDS.toMillis(idleTimeout.toLong()))
+        client.properties.putIfAbsent(CoreModuleProperties.NIO2_READ_TIMEOUT.name, TimeUnit.SECONDS.toMillis(idleTimeout + 15L))
         client.start()
 
         startSession()
@@ -214,8 +224,13 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
     private fun authSession() {
         session.addPasswordIdentity(deviceInfo.password)
         session.auth().verify(connectionTimeout, TimeUnit.SECONDS)
-        val event = session.waitFor(ImmutableSet.of(ClientSession.ClientSessionEvent.WAIT_AUTH,
-            ClientSession.ClientSessionEvent.CLOSED, ClientSession.ClientSessionEvent.AUTHED), 0)
+        val event = session.waitFor(
+            ImmutableSet.of(
+                ClientSession.ClientSessionEvent.WAIT_AUTH,
+                ClientSession.ClientSessionEvent.CLOSED, ClientSession.ClientSessionEvent.AUTHED
+            ),
+            0
+        )
         if (!event.contains(ClientSession.ClientSessionEvent.AUTHED)) {
             throw NetconfException("$deviceInfo: Failed to authenticate session.")
         }
@@ -237,8 +252,10 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
 
     private fun setupHandler() {
         val sessionListener: NetconfSessionListener = NetconfSessionListenerImpl(this)
-        streamHandler = NetconfDeviceCommunicator(channel.invertedOut, channel.invertedIn, deviceInfo,
-            sessionListener, replies)
+        streamHandler = NetconfDeviceCommunicator(
+            channel.invertedOut, channel.invertedIn, deviceInfo,
+            sessionListener, replies
+        )
 
         exchangeHelloMessage()
     }
@@ -252,12 +269,13 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
 
         if (sessionIDMatcher.find()) {
             sessionId = sessionIDMatcher.group(1)
+            log.info("netconf exchangeHelloMessage sessionID: $sessionId")
         } else {
             throw NetconfException("$deviceInfo: Missing sessionId in server hello message: $serverHelloResponse")
         }
 
         val capabilityMatcher = NetconfMessageUtils.CAPABILITY_REGEX_PATTERN.matcher(serverHelloResponse)
-        while (capabilityMatcher.find()) { //TODO: refactor to add unit test easily for device capability accumulation.
+        while (capabilityMatcher.find()) { // TODO: refactor to add unit test easily for device capability accumulation.
             deviceCapabilities.add(capabilityMatcher.group(1))
         }
     }
@@ -287,6 +305,7 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
      */
     @Throws(IOException::class)
     private fun close() {
+        log.debug("close was called.")
         session.close()
         // Closes the socket which should interrupt the streamHandler
         channel.close()
@@ -302,9 +321,18 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
      * internal function for accessing errorReplies for testing.
      */
     internal fun getErrorReplies() = errorReplies
+
     internal fun clearErrorReplies() = errorReplies.clear()
     internal fun clearReplies() = replies.clear()
-    internal fun setClient(client: SshClient) { this.client = client }
-    internal fun setSession(session: ClientSession) { this.session = session }
-    internal fun setChannel(channel: ClientChannel) { this.channel = channel }
-}
\ No newline at end of file
+    internal fun setClient(client: SshClient) {
+        this.client = client
+    }
+
+    internal fun setSession(session: ClientSession) {
+        this.session = session
+    }
+
+    internal fun setChannel(channel: ClientChannel) {
+        this.channel = channel
+    }
+}