Fix interrupt state left ON on NetconfSessionImpl.syncRpc()
[ccsdk/cds.git] / ms / blueprintsprocessor / functions / netconf-executor / src / main / kotlin / org / onap / ccsdk / cds / blueprintsprocessor / functions / netconf / executor / core / NetconfSessionImpl.kt
1 /*
2  * Copyright © 2017-2019 AT&T, Bell Canada
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.core
18
19 import com.google.common.collect.ImmutableList
20 import com.google.common.collect.ImmutableSet
21 import org.apache.sshd.client.SshClient
22 import org.apache.sshd.client.channel.ClientChannel
23 import org.apache.sshd.client.session.ClientSession
24 import org.apache.sshd.common.FactoryManager
25 import org.apache.sshd.server.keyprovider.SimpleGeneratorHostKeyProvider
26 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.DeviceInfo
27 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfException
28 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfReceivedEvent
29 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfRpcService
30 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfSession
31 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.NetconfSessionListener
32 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.NetconfMessageUtils
33 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.RpcMessageUtils
34 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.RpcStatus
35 import org.slf4j.LoggerFactory
36 import java.io.IOException
37 import java.util.Collections
38 import java.util.concurrent.CompletableFuture
39 import java.util.concurrent.ConcurrentHashMap
40 import java.util.concurrent.ExecutionException
41 import java.util.concurrent.TimeUnit
42 import java.util.concurrent.TimeoutException
43
/**
 * [NetconfSession] implementation that talks NETCONF over the SSH `netconf` subsystem channel
 * using an Apache SSHD [SshClient].
 *
 * Connection setup is layered: SSH client -> SSH session (password auth) -> `netconf` subsystem
 * channel -> [NetconfDeviceCommunicator] stream handler -> NETCONF `<hello>` exchange.
 * Replies from the device are delivered through [NetconfSessionListenerImpl] into [replies], keyed by
 * message id; error replies are collected in [errorReplies].
 *
 * @param deviceInfo connection parameters (address, credentials, connect/reply/idle timeouts in seconds).
 * @param rpcService service used by [disconnect] to send the NETCONF close-session RPC.
 */
class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcService: NetconfRpcService) :
    NetconfSession {

    private val log = LoggerFactory.getLogger(NetconfSessionImpl::class.java)

    // Error replies reported by the listener; synchronized because the listener may call in from another thread.
    private val errorReplies: MutableList<String> = Collections.synchronizedList(mutableListOf())
    // Pending reply futures keyed by NETCONF message id; shared with the NetconfDeviceCommunicator.
    private val replies: MutableMap<String, CompletableFuture<String>> = ConcurrentHashMap()
    // Capabilities advertised by the device in its <hello>; accumulated across hello exchanges.
    private val deviceCapabilities = mutableSetOf<String>()

    // Timeouts copied from deviceInfo by startConnection(); all expressed in seconds.
    private var connectionTimeout: Long = 0
    private var replyTimeout: Int = 0
    private var idleTimeout: Int = 0
    private var sessionId: String? = null

    private lateinit var session: ClientSession
    private lateinit var client: SshClient
    private lateinit var channel: ClientChannel
    private lateinit var streamHandler: NetconfDeviceCommunicator

    // Capabilities this client advertises in its own <hello> (NETCONF 1.0 and 1.1).
    private var capabilities =
        ImmutableList.of(RpcMessageUtils.NETCONF_10_CAPABILITY, RpcMessageUtils.NETCONF_11_CAPABILITY)

    /**
     * Establishes the full SSH + NETCONF connection to the device.
     *
     * @throws NetconfException if any step of the connection or the hello exchange fails.
     */
    override fun connect() {
        try {
            log.info("$deviceInfo: Connecting to Netconf Device with timeouts C:${deviceInfo.connectTimeout}, " +
                    "R:${deviceInfo.replyTimeout}, I:${deviceInfo.idleTimeout}")
            startConnection()
            log.info("$deviceInfo: Connected to Netconf Device")
        } catch (e: NetconfException) {
            log.error("$deviceInfo: Netconf Device Connection Failed. ${e.message}")
            throw NetconfException(e)
        }
    }

    /**
     * Sends the NETCONF close-session RPC, then tears down the SSH resources.
     *
     * If the first `closeSession(false)` call reports `FAILURE` (case-insensitive), it is retried with
     * `closeSession(true)` (presumably a forced/kill close — TODO confirm against NetconfRpcService).
     * Errors while closing the SSH resources are logged, not rethrown.
     */
    override fun disconnect() {
        if (rpcService.closeSession(false).status.equals(
                RpcStatus.FAILURE, true)) {
            rpcService.closeSession(true)
        }
        closeQuietly()
    }

    /** Drops the current connection and establishes a new one. */
    override fun reconnect() {
        disconnect()
        connect()
    }

    /**
     * Sends [request] and blocks until the reply for [messageId] arrives or `replyTimeout` seconds elapse.
     *
     * The connection is checked (and re-established if needed) before sending.
     *
     * @return the raw reply received from the device.
     * @throws NetconfException if the wait is interrupted (the thread's interrupt status is restored),
     *   times out, or fails with an [ExecutionException] — in the last case the session is closed and
     *   all pending replies and error replies are discarded.
     */
    override fun syncRpc(request: String, messageId: String): String {
        val formattedRequest = NetconfMessageUtils.formatRPCRequest(request, messageId, deviceCapabilities)

        checkAndReestablish()

        try {
            return streamHandler.getFutureFromSendMessage(streamHandler.sendMessage(formattedRequest, messageId),
                replyTimeout.toLong(), TimeUnit.SECONDS)
        } catch (e: InterruptedException) {
            // The interrupt status is typically cleared when InterruptedException is thrown; restore it so
            // code further up the stack (e.g. an executor being shut down) can still observe the interruption.
            Thread.currentThread().interrupt()
            throw NetconfException("$deviceInfo: Interrupted while waiting for reply for request: $formattedRequest", e)
        } catch (e: TimeoutException) {
            throw NetconfException("$deviceInfo: Timed out while waiting for reply for request $formattedRequest after $replyTimeout sec.",
                e)
        } catch (e: ExecutionException) {
            log.warn("$deviceInfo: Closing session($sessionId) due to unexpected Error", e)
            closeQuietly()
            clearErrorReplies()
            clearReplies()

            throw NetconfException("$deviceInfo: Closing session $sessionId for request $formattedRequest", e)
        }
    }

    /**
     * Sends [request] without blocking.
     *
     * @return a future completed with the device reply; if sending/receiving fails, the future completes
     *   exceptionally with a [NetconfException] carrying [messageId] as its message.
     */
    override fun asyncRpc(request: String, messageId: String): CompletableFuture<String> {
        val formattedRequest = NetconfMessageUtils.formatRPCRequest(request, messageId, deviceCapabilities)

        checkAndReestablish()

        return streamHandler.sendMessage(formattedRequest, messageId).handleAsync { reply, t ->
            if (t != null) {
                throw NetconfException(messageId, t)
            }
            reply
        }
    }

    /**
     * Re-opens the outermost closed layer of the connection (client, then session, then channel).
     * Pending replies are discarded before re-opening. Does nothing when everything is open.
     *
     * @throws NetconfException if re-opening fails with an [IOException] or [IllegalStateException].
     */
    override fun checkAndReestablish() {
        try {
            when {
                client.isClosed -> {
                    log.info("Trying to restart the whole SSH connection with {}", deviceInfo)
                    clearReplies()
                    startConnection()
                }
                session.isClosed -> {
                    log.info("Trying to restart the session with {}", deviceInfo)
                    clearReplies()
                    startSession()
                }
                channel.isClosed -> {
                    log.info("Trying to reopen the channel with {}", deviceInfo)
                    clearReplies()
                    openChannel()
                }
                else -> return
            }
        } catch (e: IOException) {
            throw reestablishFailure(e)
        } catch (e: IllegalStateException) {
            throw reestablishFailure(e)
        }
    }

    /** Logs a failed re-open attempt and builds the [NetconfException] thrown by [checkAndReestablish]. */
    private fun reestablishFailure(cause: Exception): NetconfException {
        log.error("Can't reopen connection for device {} error: {}", deviceInfo, cause.message)
        return NetconfException(String.format("Cannot re-open the connection with device (%s)", deviceInfo), cause)
    }

    override fun getDeviceInfo(): DeviceInfo {
        return deviceInfo
    }

    /** Returns the NETCONF session id; throws if no hello exchange has happened yet. */
    override fun getSessionId(): String {
        return this.sessionId!!
    }

    /** Read-only view of the capabilities advertised by the device. */
    override fun getDeviceCapabilitiesSet(): Set<String> {
        return Collections.unmodifiableSet(deviceCapabilities)
    }

    /** Copies the timeouts from [deviceInfo] and starts the SSH client; any failure becomes a [NetconfException]. */
    private fun startConnection() {
        connectionTimeout = deviceInfo.connectTimeout
        replyTimeout = deviceInfo.replyTimeout
        idleTimeout = deviceInfo.idleTimeout
        try {
            startClient()
        } catch (e: Exception) {
            throw NetconfException("$deviceInfo: Failed to establish SSH session", e)
        }
    }

    //Needed to unit test connect method interacting with client.start in startClient() below
    private fun setupNewSSHClient() {
        client = SshClient.setUpDefaultClient()
    }

    /**
     * Creates and starts a new SSH client, then opens a session.
     * Idle timeout = `idleTimeout` s; NIO2 read timeout = `idleTimeout + 15` s (both only if not already set).
     */
    private fun startClient() {
        setupNewSSHClient()

        client.properties.putIfAbsent(FactoryManager.IDLE_TIMEOUT, TimeUnit.SECONDS.toMillis(idleTimeout.toLong()))
        client.properties.putIfAbsent(FactoryManager.NIO2_READ_TIMEOUT, TimeUnit.SECONDS.toMillis(idleTimeout + 15L))
        client.start()

        startSession()
    }

    /** Connects an SSH session to the device (waiting up to `connectionTimeout` seconds) and authenticates it. */
    private fun startSession() {
        log.info("$deviceInfo: Starting SSH session")
        val connectFuture = client.connect(deviceInfo.username, deviceInfo.ipAddress, deviceInfo.port)
            .verify(connectionTimeout, TimeUnit.SECONDS)
        session = connectFuture.session
        log.info("$deviceInfo: SSH session created")

        authSession()
    }

    /**
     * Authenticates the session with the device password and opens the NETCONF channel.
     *
     * @throws NetconfException if the session does not reach the AUTHED state.
     */
    private fun authSession() {
        session.addPasswordIdentity(deviceInfo.password)
        session.auth().verify(connectionTimeout, TimeUnit.SECONDS)
        // NOTE(review): timeout 0 is passed to waitFor; confirm the SSHD semantics of a zero timeout.
        val event = session.waitFor(ImmutableSet.of(ClientSession.ClientSessionEvent.WAIT_AUTH,
            ClientSession.ClientSessionEvent.CLOSED, ClientSession.ClientSessionEvent.AUTHED), 0)
        if (!event.contains(ClientSession.ClientSessionEvent.AUTHED)) {
            throw NetconfException("$deviceInfo: Failed to authenticate session.")
        }
        log.info("$deviceInfo: SSH session authenticated")

        openChannel()
    }

    /**
     * Opens the `netconf` subsystem channel (waiting up to `connectionTimeout` seconds) and wires the stream handler.
     *
     * @throws NetconfException if the channel is not opened in time.
     */
    private fun openChannel() {
        channel = session.createSubsystemChannel("netconf")
        val channelFuture = channel.open()
        if (channelFuture.await(connectionTimeout, TimeUnit.SECONDS) && channelFuture.isOpened) {
            log.info("$deviceInfo: SSH NETCONF subsystem channel opened")
            setupHandler()
        } else {
            throw NetconfException("$deviceInfo: Failed to open SSH subsystem channel")
        }
    }

    /** Creates the stream handler on the channel's streams and performs the NETCONF hello exchange. */
    private fun setupHandler() {
        val sessionListener: NetconfSessionListener = NetconfSessionListenerImpl(this)
        streamHandler = NetconfDeviceCommunicator(channel.invertedOut, channel.invertedIn, deviceInfo,
            sessionListener, replies)

        exchangeHelloMessage()
    }

    /**
     * Sends the client `<hello>` and parses the server hello: extracts the session id and adds every
     * advertised capability to [deviceCapabilities] (existing entries are kept).
     *
     * @throws NetconfException if the server hello carries no session id.
     */
    private fun exchangeHelloMessage() {
        // Placeholder id until the server hello provides the real session id.
        sessionId = "-1"
        val messageId = "-1"

        val serverHelloResponse = syncRpc(NetconfMessageUtils.createHelloString(capabilities), messageId)
        val sessionIDMatcher = NetconfMessageUtils.SESSION_ID_REGEX_PATTERN.matcher(serverHelloResponse)

        if (sessionIDMatcher.find()) {
            sessionId = sessionIDMatcher.group(1)
        } else {
            throw NetconfException("$deviceInfo: Missing sessionId in server hello message: $serverHelloResponse")
        }

        val capabilityMatcher = NetconfMessageUtils.CAPABILITY_REGEX_PATTERN.matcher(serverHelloResponse)
        while (capabilityMatcher.find()) { //TODO: refactor to add unit test easily for device capability accumulation.
            deviceCapabilities.add(capabilityMatcher.group(1))
        }
    }

    internal fun setStreamHandler(streamHandler: NetconfDeviceCommunicator) {
        this.streamHandler = streamHandler
    }

    /**
     * Add an error reply
     * Used by {@link NetconfSessionListenerImpl}
     */
    internal fun addDeviceErrorReply(errReply: String) {
        errorReplies.add(errReply)
    }

    /**
     * Add a reply from the device
     * Used by {@link NetconfSessionListenerImpl}
     * Replies whose message id has no pending future are ignored.
     */
    internal fun addDeviceReply(messageId: String, replyMsg: String) {
        replies[messageId]?.complete(replyMsg)
    }

    /**
     * Closes the session/channel/client.
     *
     * Every resource that has been initialized is closed even if closing an earlier one throws; the
     * first failure propagates (a later failure thrown from a `finally` would replace it).
     */
    @Throws(IOException::class)
    private fun close() {
        try {
            if (::session.isInitialized) session.close()
        } finally {
            try {
                // Closes the socket which should interrupt the streamHandler
                if (::channel.isInitialized) channel.close()
            } finally {
                if (::client.isInitialized) client.close()
            }
        }
    }

    /** Calls [close], logging (instead of propagating) any [IOException]. */
    private fun closeQuietly() {
        try {
            close()
        } catch (ioe: IOException) {
            log.warn("$deviceInfo: Error closing session($sessionId) for host($deviceInfo)", ioe)
        }
    }

    /**
     * Internal function for accessing replies for testing.
     */
    internal fun getReplies() = replies

    /**
     * internal function for accessing errorReplies for testing.
     */
    internal fun getErrorReplies() = errorReplies
    internal fun clearErrorReplies() = errorReplies.clear()
    internal fun clearReplies() = replies.clear()
    internal fun setClient(client: SshClient) { this.client = client }
    internal fun setSession(session: ClientSession) { this.session = session }
    internal fun setChannel(channel: ClientChannel) { this.channel = channel }
}