reduce CDS java security vulnerabilities
[ccsdk/cds.git] / ms / blueprintsprocessor / functions / netconf-executor / src / main / kotlin / org / onap / ccsdk / cds / blueprintsprocessor / functions / netconf / executor / core / NetconfSessionImpl.kt
index d1ecb4f..31d90fd 100644 (file)
@@ -19,15 +19,11 @@ package org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.core
 import com.google.common.collect.ImmutableList
 import com.google.common.collect.ImmutableSet
 import org.apache.sshd.client.SshClient
-import org.apache.sshd.client.channel.ChannelSubsystem
 import org.apache.sshd.client.channel.ClientChannel
 import org.apache.sshd.client.session.ClientSession
-import org.apache.sshd.client.session.ClientSessionImpl
-import org.apache.sshd.common.FactoryManager
-import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider
+import org.apache.sshd.core.CoreModuleProperties
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.DeviceInfo
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfException
-import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfReceivedEvent
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfRpcService
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfSession
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfSessionListener
@@ -36,22 +32,21 @@ import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.R
 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.RpcStatus
 import org.slf4j.LoggerFactory
 import java.io.IOException
-import java.util.*
+import java.util.Collections
 import java.util.concurrent.CompletableFuture
 import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.ExecutionException
 import java.util.concurrent.TimeUnit
 import java.util.concurrent.TimeoutException
-import java.util.concurrent.atomic.AtomicReference
 
 class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcService: NetconfRpcService) :
     NetconfSession {
 
     private val log = LoggerFactory.getLogger(NetconfSessionImpl::class.java)
 
-    private val errorReplies: MutableList<String> = Collections.synchronizedList(listOf())
+    private val errorReplies: MutableList<String> = Collections.synchronizedList(mutableListOf())
     private val replies: MutableMap<String, CompletableFuture<String>> = ConcurrentHashMap()
-    private val deviceCapabilities = setOf<String>()
+    private val deviceCapabilities = mutableSetOf<String>()
 
     private var connectionTimeout: Long = 0
     private var replyTimeout: Int = 0
@@ -68,8 +63,10 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
 
     override fun connect() {
         try {
-            log.info("$deviceInfo: Connecting to Netconf Device with timeouts C:${deviceInfo.connectTimeout}, " +
-                    "R:${deviceInfo.replyTimeout}, I:${deviceInfo.idleTimeout}")
+            log.info(
+                "$deviceInfo: Connecting to Netconf Device with timeouts C:${deviceInfo.connectTimeout}, " +
+                    "R:${deviceInfo.replyTimeout}, I:${deviceInfo.idleTimeout}"
+            )
             startConnection()
             log.info("$deviceInfo: Connected to Netconf Device")
         } catch (e: NetconfException) {
@@ -79,15 +76,23 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
     }
 
     override fun disconnect() {
-        if (rpcService.closeSession(false).status.equals(
-                RpcStatus.FAILURE, true)) {
+        var retryNum = 3
+        while (rpcService.closeSession(false).status
+            .equals(RpcStatus.FAILURE, true) && retryNum > 0
+        ) {
+            log.error("disconnect: graceful disconnect failed, retrying $retryNum times...")
+            retryNum--
+        }
+        // if we can't close the session, try to force terminate.
+        if (retryNum == 0) {
+            log.error("disconnect: trying to force-terminate the session.")
             rpcService.closeSession(true)
         }
-
-        session.close()
-        // Closes the socket which should interrupt the streamHandler
-        channel.close()
-        client.close()
+        try {
+            close()
+        } catch (ioe: IOException) {
+            log.warn("$deviceInfo: Error closing session($sessionId) for host($deviceInfo)", ioe)
+        }
     }
 
     override fun reconnect() {
@@ -101,29 +106,26 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
         checkAndReestablish()
 
         try {
-            return streamHandler.sendMessage(formattedRequest, messageId).get(replyTimeout.toLong(), TimeUnit.SECONDS)
-//            replies.remove(messageId)
+            return streamHandler.getFutureFromSendMessage(
+                streamHandler.sendMessage(formattedRequest, messageId),
+                replyTimeout.toLong(), TimeUnit.SECONDS
+            )
         } catch (e: InterruptedException) {
-            Thread.currentThread().interrupt()
             throw NetconfException("$deviceInfo: Interrupted while waiting for reply for request: $formattedRequest", e)
         } catch (e: TimeoutException) {
-            throw NetconfException("$deviceInfo: Timed out while waiting for reply for request $formattedRequest after $replyTimeout sec.",
-                e)
+            throw NetconfException(
+                "$deviceInfo: Timed out while waiting for reply for request $formattedRequest after $replyTimeout sec.",
+                e
+            )
         } catch (e: ExecutionException) {
             log.warn("$deviceInfo: Closing session($sessionId) due to unexpected Error", e)
             try {
-                session.close()
-                // Closes the socket which should interrupt the streamHandler
-                channel.close()
-                client.close()
+                close()
             } catch (ioe: IOException) {
                 log.warn("$deviceInfo: Error closing session($sessionId) for host($deviceInfo)", ioe)
             }
-
-//            NetconfReceivedEvent(NetconfReceivedEvent.Type.SESSION_CLOSED, "",
-//                "Closed due to unexpected error " + e.cause, "-1", deviceInfo)
-            errorReplies.clear() // move to cleanUp()?
-            replies.clear()
+            clearErrorReplies()
+            clearReplies()
 
             throw NetconfException("$deviceInfo: Closing session $sessionId for request $formattedRequest", e)
         }
@@ -144,29 +146,31 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
 
     override fun checkAndReestablish() {
         try {
-            if (client.isClosed) {
-                log.info("Trying to restart the whole SSH connection with {}", deviceInfo)
-                replies.clear()
-                startConnection()
-            } else if (session.isClosed) {
-                log.info("Trying to restart the session with {}", deviceInfo)
-                replies.clear()
-                startSession()
-            } else if (channel.isClosed) {
-                log.info("Trying to reopen the channel with {}", deviceInfo)
-                replies.clear()
-                openChannel()
-            } else {
-                return
+            when {
+                client.isClosed -> {
+                    log.info("Trying to restart the whole SSH connection with {}", deviceInfo)
+                    clearReplies()
+                    startConnection()
+                }
+                session.isClosed -> {
+                    log.info("Trying to restart the session with {}", deviceInfo)
+                    clearReplies()
+                    startSession()
+                }
+                channel.isClosed -> {
+                    log.info("Trying to reopen the channel with {}", deviceInfo)
+                    clearReplies()
+                    openChannel()
+                }
+                else -> return
             }
         } catch (e: IOException) {
-            log.error("Can't reopen connection for device {}", e.message)
+            log.error("Can't reopen connection for device {} error: {}", deviceInfo, e.message)
             throw NetconfException(String.format("Cannot re-open the connection with device (%s)", deviceInfo), e)
         } catch (e: IllegalStateException) {
-            log.error("Can't reopen connection for device {}", e.message)
+            log.error("Can't reopen connection for device {} error: {}", deviceInfo, e.message)
             throw NetconfException(String.format("Cannot re-open the connection with device (%s)", deviceInfo), e)
         }
-
     }
 
     override fun getDeviceInfo(): DeviceInfo {
@@ -190,14 +194,18 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
         } catch (e: Exception) {
             throw NetconfException("$deviceInfo: Failed to establish SSH session", e)
         }
+    }
 
+    // Needed to unit test connect method interacting with client.start in startClient() below
+    private fun setupNewSSHClient() {
+        client = SshClient.setUpDefaultClient()
     }
 
     private fun startClient() {
-        client = SshClient.setUpDefaultClient()
-        client.properties.putIfAbsent(FactoryManager.IDLE_TIMEOUT, TimeUnit.SECONDS.toMillis(idleTimeout.toLong()))
-        client.properties.putIfAbsent(FactoryManager.NIO2_READ_TIMEOUT, TimeUnit.SECONDS.toMillis(idleTimeout + 15L))
-        client.keyPairProvider = SimpleGeneratorHostKeyProvider()
+        setupNewSSHClient()
+
+        client.properties.putIfAbsent(CoreModuleProperties.IDLE_TIMEOUT.name, TimeUnit.SECONDS.toMillis(idleTimeout.toLong()))
+        client.properties.putIfAbsent(CoreModuleProperties.NIO2_READ_TIMEOUT.name, TimeUnit.SECONDS.toMillis(idleTimeout + 15L))
         client.start()
 
         startSession()
@@ -216,8 +224,13 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
     private fun authSession() {
         session.addPasswordIdentity(deviceInfo.password)
         session.auth().verify(connectionTimeout, TimeUnit.SECONDS)
-        val event = session.waitFor(ImmutableSet.of(ClientSession.ClientSessionEvent.WAIT_AUTH,
-            ClientSession.ClientSessionEvent.CLOSED, ClientSession.ClientSessionEvent.AUTHED), 0)
+        val event = session.waitFor(
+            ImmutableSet.of(
+                ClientSession.ClientSessionEvent.WAIT_AUTH,
+                ClientSession.ClientSessionEvent.CLOSED, ClientSession.ClientSessionEvent.AUTHED
+            ),
+            0
+        )
         if (!event.contains(ClientSession.ClientSessionEvent.AUTHED)) {
             throw NetconfException("$deviceInfo: Failed to authenticate session.")
         }
@@ -238,9 +251,11 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
     }
 
     private fun setupHandler() {
-        val sessionListener: NetconfSessionListener = NetconfSessionListenerImpl()
-        streamHandler = NetconfDeviceCommunicator(channel.invertedOut, channel.invertedIn, deviceInfo,
-            sessionListener, replies)
+        val sessionListener: NetconfSessionListener = NetconfSessionListenerImpl(this)
+        streamHandler = NetconfDeviceCommunicator(
+            channel.invertedOut, channel.invertedIn, deviceInfo,
+            sessionListener, replies
+        )
 
         exchangeHelloMessage()
     }
@@ -254,34 +269,70 @@ class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcServ
 
         if (sessionIDMatcher.find()) {
             sessionId = sessionIDMatcher.group(1)
+            log.info("netconf exchangeHelloMessage sessionID: $sessionId")
         } else {
             throw NetconfException("$deviceInfo: Missing sessionId in server hello message: $serverHelloResponse")
         }
 
         val capabilityMatcher = NetconfMessageUtils.CAPABILITY_REGEX_PATTERN.matcher(serverHelloResponse)
-        while (capabilityMatcher.find()) {
-            deviceCapabilities.plus(capabilityMatcher.group(1))
+        while (capabilityMatcher.find()) { // TODO: refactor to add unit test easily for device capability accumulation.
+            deviceCapabilities.add(capabilityMatcher.group(1))
         }
     }
 
-    inner class NetconfSessionListenerImpl : NetconfSessionListener {
-        override fun notify(event: NetconfReceivedEvent) {
-            val messageId = event.getMessageID()
+    internal fun setStreamHandler(streamHandler: NetconfDeviceCommunicator) {
+        this.streamHandler = streamHandler
+    }
 
-            when (event.getType()) {
-                NetconfReceivedEvent.Type.DEVICE_UNREGISTERED -> disconnect()
-                NetconfReceivedEvent.Type.DEVICE_ERROR -> errorReplies.add(event.getMessagePayload())
-                NetconfReceivedEvent.Type.DEVICE_REPLY -> replies[messageId]?.complete(event.getMessagePayload())
-                NetconfReceivedEvent.Type.SESSION_CLOSED -> disconnect()
-            }
-        }
+    /**
+     * Add an error reply
+     * Used by {@link NetconfSessionListenerImpl}
+     */
+    internal fun addDeviceErrorReply(errReply: String) {
+        errorReplies.add(errReply)
     }
 
-    fun sessionstatus(state:String): Boolean{
-        return when (state){
-            "Close" -> channel.isClosed
-            "Open" -> channel.isOpen
-            else -> false
-        }
+    /**
+     * Add a reply from the device
+     * Used by {@link NetconfSessionListenerImpl}
+     */
+    internal fun addDeviceReply(messageId: String, replyMsg: String) {
+        replies[messageId]?.complete(replyMsg)
+    }
+
+    /**
+     * Closes the session/channel/client
+     */
+    @Throws(IOException::class)
+    private fun close() {
+        log.debug("close was called.")
+        session.close()
+        // Closes the socket which should interrupt the streamHandler
+        channel.close()
+        client.close()
+    }
+
+    /**
+     * Internal function for accessing replies for testing.
+     */
+    internal fun getReplies() = replies
+
+    /**
+     * internal function for accessing errorReplies for testing.
+     */
+    internal fun getErrorReplies() = errorReplies
+
+    internal fun clearErrorReplies() = errorReplies.clear()
+    internal fun clearReplies() = replies.clear()
+    internal fun setClient(client: SshClient) {
+        this.client = client
+    }
+
+    internal fun setSession(session: ClientSession) {
+        this.session = session
+    }
+
+    internal fun setChannel(channel: ClientChannel) {
+        this.channel = channel
     }
-}
\ No newline at end of file
+}