Reduce CDS Java security vulnerabilities
[ccsdk/cds.git] / ms / blueprintsprocessor / functions / netconf-executor / src / main / kotlin / org / onap / ccsdk / cds / blueprintsprocessor / functions / netconf / executor / core / NetconfSessionImpl.kt
1 /*
2  * Copyright © 2017-2019 AT&T, Bell Canada
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.core
18
19 import com.google.common.collect.ImmutableList
20 import com.google.common.collect.ImmutableSet
21 import org.apache.sshd.client.SshClient
22 import org.apache.sshd.client.channel.ClientChannel
23 import org.apache.sshd.client.session.ClientSession
24 import org.apache.sshd.core.CoreModuleProperties
25 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.DeviceInfo
26 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfException
27 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfRpcService
28 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfSession
29 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfSessionListener
30 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.NetconfMessageUtils
31 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.RpcMessageUtils
32 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.RpcStatus
33 import org.slf4j.LoggerFactory
34 import java.io.IOException
35 import java.util.Collections
36 import java.util.concurrent.CompletableFuture
37 import java.util.concurrent.ConcurrentHashMap
38 import java.util.concurrent.ExecutionException
39 import java.util.concurrent.TimeUnit
40 import java.util.concurrent.TimeoutException
41
/**
 * [NetconfSession] implementation that talks to a device through an Apache SSHD client,
 * an SSH session and a `netconf` subsystem channel.
 *
 * Requests are written through a [NetconfDeviceCommunicator]. Pending replies are tracked as futures
 * keyed by message id in [replies]; [addDeviceReply] completes them.
 *
 * @param deviceInfo address, credentials and timeouts (seconds) of the target device
 * @param rpcService service used by [disconnect] to close the NETCONF session, gracefully or forcibly
 */
class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcService: NetconfRpcService) :
    NetconfSession {

    private val log = LoggerFactory.getLogger(NetconfSessionImpl::class.java)

    // Error replies recorded through addDeviceErrorReply().
    private val errorReplies: MutableList<String> = Collections.synchronizedList(mutableListOf())

    // Pending reply futures keyed by message id; the same map is handed to the NetconfDeviceCommunicator.
    private val replies: MutableMap<String, CompletableFuture<String>> = ConcurrentHashMap()

    // Capabilities advertised in the device's hello message for the current session.
    private val deviceCapabilities = mutableSetOf<String>()

    // Copied from deviceInfo on every startConnection(); all values are interpreted as seconds.
    private var connectionTimeout: Long = 0
    private var replyTimeout: Int = 0
    private var idleTimeout: Int = 0

    // "-1" while the hello exchange is in progress, afterwards the session id reported by the device.
    private var sessionId: String? = null

    private lateinit var session: ClientSession
    private lateinit var client: SshClient
    private lateinit var channel: ClientChannel
    private lateinit var streamHandler: NetconfDeviceCommunicator

    // Capabilities this client advertises in its own hello message.
    private val capabilities =
        ImmutableList.of(RpcMessageUtils.NETCONF_10_CAPABILITY, RpcMessageUtils.NETCONF_11_CAPABILITY)

    /**
     * Starts the SSH client, session and NETCONF channel, then performs the hello exchange.
     *
     * @throws NetconfException if any step of the connection fails
     */
    override fun connect() {
        try {
            log.info(
                "$deviceInfo: Connecting to Netconf Device with timeouts C:${deviceInfo.connectTimeout}, " +
                    "R:${deviceInfo.replyTimeout}, I:${deviceInfo.idleTimeout}"
            )
            startConnection()
            log.info("$deviceInfo: Connected to Netconf Device")
        } catch (e: NetconfException) {
            log.error("$deviceInfo: Netconf Device Connection Failed. ${e.message}")
            throw NetconfException(e)
        }
    }

    /**
     * Closes the NETCONF session and releases the SSH resources.
     *
     * A graceful close is attempted once and then retried up to 3 more times. Only if every attempt
     * reports a failure is a forced close requested. Errors while closing the SSH objects are logged,
     * not rethrown.
     */
    override fun disconnect() {
        var retriesLeft = 3
        var gracefullyClosed = tryGracefulClose()
        while (!gracefullyClosed && retriesLeft > 0) {
            log.error("disconnect: graceful disconnect failed, retrying $retriesLeft times...")
            retriesLeft--
            gracefullyClosed = tryGracefulClose()
        }
        // if we can't close the session, try to force terminate.
        // Decided on the outcome of the last attempt, so a late success does not trigger a forced close.
        if (!gracefullyClosed) {
            log.error("disconnect: trying to force-terminate the session.")
            rpcService.closeSession(true)
        }
        try {
            close()
        } catch (ioe: IOException) {
            log.warn("$deviceInfo: Error closing session($sessionId) for host($deviceInfo)", ioe)
        }
    }

    /** Requests a graceful close; returns true unless the reported status is FAILURE (case-insensitive). */
    private fun tryGracefulClose(): Boolean =
        !rpcService.closeSession(false).status.equals(RpcStatus.FAILURE, true)

    /** Disconnects and then connects again. */
    override fun reconnect() {
        disconnect()
        connect()
    }

    /**
     * Sends [request] and blocks until the reply arrives or the reply timeout (seconds) elapses.
     *
     * On an [ExecutionException] the SSH resources are closed, and pending replies and error replies
     * are cleared.
     *
     * @return the device reply
     * @throws NetconfException on interruption, timeout or execution failure
     */
    override fun syncRpc(request: String, messageId: String): String {
        val formattedRequest = NetconfMessageUtils.formatRPCRequest(request, messageId, deviceCapabilities)

        checkAndReestablish()

        try {
            return streamHandler.getFutureFromSendMessage(
                streamHandler.sendMessage(formattedRequest, messageId),
                replyTimeout.toLong(), TimeUnit.SECONDS
            )
        } catch (e: InterruptedException) {
            throw NetconfException("$deviceInfo: Interrupted while waiting for reply for request: $formattedRequest", e)
        } catch (e: TimeoutException) {
            throw NetconfException(
                "$deviceInfo: Timed out while waiting for reply for request $formattedRequest after $replyTimeout sec.",
                e
            )
        } catch (e: ExecutionException) {
            log.warn("$deviceInfo: Closing session($sessionId) due to unexpected Error", e)
            try {
                close()
            } catch (ioe: IOException) {
                log.warn("$deviceInfo: Error closing session($sessionId) for host($deviceInfo)", ioe)
            }
            clearErrorReplies()
            clearReplies()

            throw NetconfException("$deviceInfo: Closing session $sessionId for request $formattedRequest", e)
        }
    }

    /**
     * Sends [request] without waiting for the reply.
     *
     * @return a future completed with the reply. A failure of the underlying future is rethrown inside
     *   the handler as a [NetconfException] carrying [messageId].
     */
    override fun asyncRpc(request: String, messageId: String): CompletableFuture<String> {
        val formattedRequest = NetconfMessageUtils.formatRPCRequest(request, messageId, deviceCapabilities)

        checkAndReestablish()

        return streamHandler.sendMessage(formattedRequest, messageId).handleAsync { reply, t ->
            if (t != null) {
                throw NetconfException(messageId, t)
            }
            reply
        }
    }

    /**
     * Re-establishes the outermost closed layer: the whole client, only the session, or only the channel.
     * Pending replies are cleared before any re-establishment. Does nothing when all layers are open.
     *
     * @throws NetconfException if re-opening fails with an [IOException] or [IllegalStateException]
     */
    override fun checkAndReestablish() {
        try {
            when {
                client.isClosed -> {
                    log.info("Trying to restart the whole SSH connection with {}", deviceInfo)
                    clearReplies()
                    startConnection()
                }
                session.isClosed -> {
                    log.info("Trying to restart the session with {}", deviceInfo)
                    clearReplies()
                    startSession()
                }
                channel.isClosed -> {
                    log.info("Trying to reopen the channel with {}", deviceInfo)
                    clearReplies()
                    openChannel()
                }
                else -> return
            }
        } catch (e: IOException) {
            throw reopenFailure(e)
        } catch (e: IllegalStateException) {
            throw reopenFailure(e)
        }
    }

    /** Logs a failed re-open attempt and wraps [cause] in the [NetconfException] thrown to the caller. */
    private fun reopenFailure(cause: Exception): NetconfException {
        log.error("Can't reopen connection for device {} error: {}", deviceInfo, cause.message)
        return NetconfException(String.format("Cannot re-open the connection with device (%s)", deviceInfo), cause)
    }

    override fun getDeviceInfo(): DeviceInfo {
        return deviceInfo
    }

    /** Returns the current session id; throws if no hello exchange has been started yet. */
    override fun getSessionId(): String {
        return this.sessionId!!
    }

    /** Returns a read-only view of the capabilities advertised by the device. */
    override fun getDeviceCapabilitiesSet(): Set<String> {
        return Collections.unmodifiableSet(deviceCapabilities)
    }

    /** Refreshes the timeouts from [deviceInfo] and starts a new SSH client; any failure is wrapped. */
    private fun startConnection() {
        connectionTimeout = deviceInfo.connectTimeout
        replyTimeout = deviceInfo.replyTimeout
        idleTimeout = deviceInfo.idleTimeout
        try {
            startClient()
        } catch (e: Exception) {
            throw NetconfException("$deviceInfo: Failed to establish SSH session", e)
        }
    }

    // Needed to unit test connect method interacting with client.start in startClient() below
    private fun setupNewSSHClient() {
        client = SshClient.setUpDefaultClient()
    }

    /** Creates and starts a new SSH client, then opens a session on it. */
    private fun startClient() {
        setupNewSSHClient()

        // Idle timeout in ms; the NIO2 read timeout is set 15 s above it. putIfAbsent keeps any preset values.
        client.properties.putIfAbsent(CoreModuleProperties.IDLE_TIMEOUT.name, TimeUnit.SECONDS.toMillis(idleTimeout.toLong()))
        client.properties.putIfAbsent(CoreModuleProperties.NIO2_READ_TIMEOUT.name, TimeUnit.SECONDS.toMillis(idleTimeout + 15L))
        client.start()

        startSession()
    }

    /** Connects a new SSH session to the device, waiting at most [connectionTimeout] seconds, then authenticates. */
    private fun startSession() {
        log.info("$deviceInfo: Starting SSH session")
        val connectFuture = client.connect(deviceInfo.username, deviceInfo.ipAddress, deviceInfo.port)
            .verify(connectionTimeout, TimeUnit.SECONDS)
        session = connectFuture.session
        log.info("$deviceInfo: SSH session created")

        authSession()
    }

    /**
     * Authenticates the session with the device password, then opens the NETCONF channel.
     *
     * @throws NetconfException if the session does not reach the AUTHED state
     */
    private fun authSession() {
        session.addPasswordIdentity(deviceInfo.password)
        session.auth().verify(connectionTimeout, TimeUnit.SECONDS)
        // NOTE(review): a timeout of 0 is passed; confirm the SSHD semantics for non-positive timeouts.
        val event = session.waitFor(
            ImmutableSet.of(
                ClientSession.ClientSessionEvent.WAIT_AUTH,
                ClientSession.ClientSessionEvent.CLOSED, ClientSession.ClientSessionEvent.AUTHED
            ),
            0
        )
        if (!event.contains(ClientSession.ClientSessionEvent.AUTHED)) {
            throw NetconfException("$deviceInfo: Failed to authenticate session.")
        }
        log.info("$deviceInfo: SSH session authenticated")

        openChannel()
    }

    /**
     * Opens the `netconf` subsystem channel and sets up the stream handler.
     *
     * @throws NetconfException if the channel is not opened within [connectionTimeout] seconds
     */
    private fun openChannel() {
        channel = session.createSubsystemChannel("netconf")
        val channelFuture = channel.open()
        if (channelFuture.await(connectionTimeout, TimeUnit.SECONDS) && channelFuture.isOpened) {
            log.info("$deviceInfo: SSH NETCONF subsystem channel opened")
            setupHandler()
        } else {
            throw NetconfException("$deviceInfo: Failed to open SSH subsystem channel")
        }
    }

    /** Wires a [NetconfDeviceCommunicator] to the channel streams, then exchanges hello messages. */
    private fun setupHandler() {
        val sessionListener: NetconfSessionListener = NetconfSessionListenerImpl(this)
        streamHandler = NetconfDeviceCommunicator(
            channel.invertedOut, channel.invertedIn, deviceInfo,
            sessionListener, replies
        )

        exchangeHelloMessage()
    }

    /**
     * Sends the client hello and parses the server hello for the session id and device capabilities.
     *
     * @throws NetconfException if the server hello has no session id
     */
    private fun exchangeHelloMessage() {
        sessionId = "-1"
        val messageId = "-1"

        // Forget capabilities from any previous session. Otherwise they would be passed to formatRPCRequest
        // while formatting the new hello and would remain in the set after reconnecting.
        deviceCapabilities.clear()

        val serverHelloResponse = syncRpc(NetconfMessageUtils.createHelloString(capabilities), messageId)
        val sessionIDMatcher = NetconfMessageUtils.SESSION_ID_REGEX_PATTERN.matcher(serverHelloResponse)

        if (sessionIDMatcher.find()) {
            sessionId = sessionIDMatcher.group(1)
            log.info("netconf exchangeHelloMessage sessionID: $sessionId")
        } else {
            throw NetconfException("$deviceInfo: Missing sessionId in server hello message: $serverHelloResponse")
        }

        val capabilityMatcher = NetconfMessageUtils.CAPABILITY_REGEX_PATTERN.matcher(serverHelloResponse)
        while (capabilityMatcher.find()) { // TODO: refactor to add unit test easily for device capability accumulation.
            deviceCapabilities.add(capabilityMatcher.group(1))
        }
    }

    internal fun setStreamHandler(streamHandler: NetconfDeviceCommunicator) {
        this.streamHandler = streamHandler
    }

    /**
     * Add an error reply
     * Used by {@link NetconfSessionListenerImpl}
     */
    internal fun addDeviceErrorReply(errReply: String) {
        errorReplies.add(errReply)
    }

    /**
     * Add a reply from the device
     * Used by {@link NetconfSessionListenerImpl}
     * Completes the pending future for [messageId]; a reply with no pending future is ignored.
     */
    internal fun addDeviceReply(messageId: String, replyMsg: String) {
        replies[messageId]?.complete(replyMsg)
    }

    /**
     * Closes the session, channel and client, in that order.
     *
     * Every resource is closed even if an earlier close fails, and resources that were never initialized
     * are skipped. The first [IOException] is rethrown after all close attempts, with any later ones
     * attached as suppressed exceptions.
     */
    @Throws(IOException::class)
    private fun close() {
        log.debug("close was called.")
        val closeSteps: List<() -> Unit> = listOf(
            { if (::session.isInitialized) session.close() },
            // Closes the socket which should interrupt the streamHandler
            { if (::channel.isInitialized) channel.close() },
            { if (::client.isInitialized) client.close() }
        )
        var failure: IOException? = null
        for (step in closeSteps) {
            try {
                step()
            } catch (e: IOException) {
                val first = failure
                if (first == null) failure = e else first.addSuppressed(e)
            }
        }
        failure?.let { throw it }
    }

    /**
     * Internal function for accessing replies for testing.
     */
    internal fun getReplies() = replies

    /**
     * internal function for accessing errorReplies for testing.
     */
    internal fun getErrorReplies() = errorReplies

    internal fun clearErrorReplies() = errorReplies.clear()
    internal fun clearReplies() = replies.clear()
    internal fun setClient(client: SshClient) {
        this.client = client
    }

    internal fun setSession(session: ClientSession) {
        this.session = session
    }

    internal fun setChannel(channel: ClientChannel) {
        this.channel = channel
    }
}