a0f653591076dcafe18f0c32b3baeb0d65115c85
[ccsdk/cds.git] /
1 /*
2  * Copyright © 2017-2019 AT&T, Bell Canada
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16
17 package org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.core
18
19 import com.google.common.collect.ImmutableList
20 import com.google.common.collect.ImmutableSet
21 import org.apache.sshd.client.SshClient
22 import org.apache.sshd.client.channel.ClientChannel
23 import org.apache.sshd.client.session.ClientSession
24 import org.apache.sshd.common.FactoryManager
25 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.api.*
26 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.NetconfMessageUtils
27 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.RpcMessageUtils
28 import org.onap.ccsdk.cds.blueprintsprocessor.functions.netconf.executor.utils.RpcStatus
29 import org.slf4j.LoggerFactory
30 import java.io.IOException
31 import java.util.*
32 import java.util.concurrent.*
33
/**
 * SSH-based [NetconfSession] for a single device.
 *
 * [connect] starts an Apache SSHD [SshClient], opens and password-authenticates a [ClientSession], opens the
 * `netconf` subsystem channel, attaches a [NetconfDeviceCommunicator] to the channel streams and finally
 * exchanges `<hello>` messages to learn the NETCONF session id and the device capabilities.
 * [checkAndReestablish] runs before every RPC and rebuilds whichever layer (client, session or channel)
 * has been closed.
 *
 * @param deviceInfo address, credentials and timeouts (connect/reply/idle, all in seconds) of the target device
 * @param rpcService RPC helper used by [disconnect] to close the NETCONF session before tearing down SSH
 */
class NetconfSessionImpl(private val deviceInfo: DeviceInfo, private val rpcService: NetconfRpcService) :
    NetconfSession {

    private val log = LoggerFactory.getLogger(NetconfSessionImpl::class.java)

    /** Error replies reported through [addDeviceErrorReply]. */
    private val errorReplies: MutableList<String> = Collections.synchronizedList(mutableListOf())

    /** Reply futures keyed by message id; shared with the [NetconfDeviceCommunicator], completed by [addDeviceReply]. */
    private val replies: MutableMap<String, CompletableFuture<String>> = ConcurrentHashMap()

    /** Capabilities parsed from the device `<hello>` of the current connection; used to format outgoing RPCs. */
    private val deviceCapabilities = mutableSetOf<String>()

    // Copied from deviceInfo on every full (re)connection, see startConnection().
    private var connectionTimeout: Long = 0
    private var replyTimeout: Int = 0
    private var idleTimeout: Int = 0

    // null until the first hello exchange; "-1" while a hello exchange is in progress.
    private var sessionId: String? = null

    private lateinit var session: ClientSession
    private lateinit var client: SshClient
    private lateinit var channel: ClientChannel
    private lateinit var streamHandler: NetconfDeviceCommunicator

    /** Capabilities this client advertises in its own `<hello>`; never reassigned. */
    private val capabilities =
        ImmutableList.of(RpcMessageUtils.NETCONF_10_CAPABILITY, RpcMessageUtils.NETCONF_11_CAPABILITY)

    /**
     * Establishes the SSH transport, the NETCONF subsystem channel and the hello exchange.
     *
     * @throws NetconfException if any step of the connection fails
     */
    override fun connect() {
        try {
            log.info("$deviceInfo: Connecting to Netconf Device with timeouts C:${deviceInfo.connectTimeout}, " +
                "R:${deviceInfo.replyTimeout}, I:${deviceInfo.idleTimeout}")
            startConnection()
            log.info("$deviceInfo: Connected to Netconf Device")
        } catch (e: NetconfException) {
            log.error("$deviceInfo: Netconf Device Connection Failed. ${e.message}")
            throw NetconfException(e)
        }
    }

    /**
     * Closes the NETCONF session, then the SSH channel, session and client.
     *
     * `closeSession(false)` is tried first; if its status equals FAILURE (case-insensitive),
     * `closeSession(true)` is sent. The SSH resources are closed even when one of these RPCs throws,
     * so an unresponsive device cannot leak the transport; the RPC exception still propagates.
     * An [IOException] while closing the transport is only logged.
     */
    override fun disconnect() {
        try {
            if (rpcService.closeSession(false).status.equals(RpcStatus.FAILURE, true)) {
                rpcService.closeSession(true)
            }
        } finally {
            closeLoggingErrors()
        }
    }

    /** Tears the connection down with [disconnect] and builds it up again with [connect]. */
    override fun reconnect() {
        disconnect()
        connect()
    }

    /**
     * Sends [request] as an RPC with id [messageId] and blocks until the reply arrives.
     *
     * Waits at most the configured reply timeout (seconds). If the reply future fails with an
     * [ExecutionException], the SSH stack is closed and all pending and error replies are discarded.
     *
     * @return the reply text produced by the stream handler
     * @throws NetconfException on interruption, timeout or execution failure
     */
    override fun syncRpc(request: String, messageId: String): String {
        val formattedRequest = NetconfMessageUtils.formatRPCRequest(request, messageId, deviceCapabilities)

        checkAndReestablish()

        try {
            return streamHandler.getFutureFromSendMessage(streamHandler.sendMessage(formattedRequest, messageId),
                replyTimeout.toLong(), TimeUnit.SECONDS)
        } catch (e: InterruptedException) {
            throw NetconfException("$deviceInfo: Interrupted while waiting for reply for request: $formattedRequest", e)
        } catch (e: TimeoutException) {
            throw NetconfException("$deviceInfo: Timed out while waiting for reply for request $formattedRequest after $replyTimeout sec.",
                e)
        } catch (e: ExecutionException) {
            log.warn("$deviceInfo: Closing session($sessionId) due to unexpected Error", e)
            closeLoggingErrors()
            clearErrorReplies()
            clearReplies()

            throw NetconfException("$deviceInfo: Closing session $sessionId for request $formattedRequest", e)
        }
    }

    /**
     * Sends [request] as an RPC with id [messageId] without waiting for the reply.
     *
     * @return a future completed with the reply; if sending fails, the future completes exceptionally
     *   (wrapping a [NetconfException] whose message is [messageId])
     */
    override fun asyncRpc(request: String, messageId: String): CompletableFuture<String> {
        val formattedRequest = NetconfMessageUtils.formatRPCRequest(request, messageId, deviceCapabilities)

        checkAndReestablish()

        return streamHandler.sendMessage(formattedRequest, messageId).handleAsync { reply, t ->
            if (t != null) {
                throw NetconfException(messageId, t)
            }
            reply
        }
    }

    /**
     * Rebuilds the outermost closed layer: the whole connection if the client is closed, otherwise the
     * session (and everything above it), otherwise only the channel. Does nothing when all are open.
     * Pending replies are discarded whenever something is rebuilt.
     *
     * @throws NetconfException if reopening fails
     */
    override fun checkAndReestablish() {
        try {
            when {
                client.isClosed -> {
                    log.info("Trying to restart the whole SSH connection with {}", deviceInfo)
                    clearReplies()
                    startConnection()
                }
                session.isClosed -> {
                    log.info("Trying to restart the session with {}", deviceInfo)
                    clearReplies()
                    startSession()
                }
                channel.isClosed -> {
                    log.info("Trying to reopen the channel with {}", deviceInfo)
                    clearReplies()
                    openChannel()
                }
                else -> return
            }
        } catch (e: IOException) {
            throw reopenFailure(e)
        } catch (e: IllegalStateException) {
            throw reopenFailure(e)
        }
    }

    /** Logs a failed re-open attempt and builds the [NetconfException] that [checkAndReestablish] throws. */
    private fun reopenFailure(e: Exception): NetconfException {
        log.error("Can't reopen connection for device {} error: {}", deviceInfo, e.message)
        return NetconfException(String.format("Cannot re-open the connection with device (%s)", deviceInfo), e)
    }

    override fun getDeviceInfo(): DeviceInfo {
        return deviceInfo
    }

    /**
     * NETCONF session id announced by the device in its `<hello>`.
     *
     * @throws NullPointerException if no hello exchange has been started yet (the id is still null)
     */
    override fun getSessionId(): String {
        // Deliberate `!!`: the interface promises a non-null id, so there is no sensible default before connect().
        return this.sessionId!!
    }

    /** Read-only view of the capabilities announced by the device in its last `<hello>`. */
    override fun getDeviceCapabilitiesSet(): Set<String> {
        return Collections.unmodifiableSet(deviceCapabilities)
    }

    /**
     * Refreshes the timeouts from [deviceInfo] and builds the full SSH/NETCONF stack.
     *
     * On failure, the SSH client started by this attempt is released before the exception is reported,
     * so repeated failed attempts do not accumulate running clients.
     *
     * @throws NetconfException wrapping whatever went wrong
     */
    private fun startConnection() {
        connectionTimeout = deviceInfo.connectTimeout
        replyTimeout = deviceInfo.replyTimeout
        idleTimeout = deviceInfo.idleTimeout
        try {
            startClient()
        } catch (e: Exception) {
            releaseClientQuietly()
            throw NetconfException("$deviceInfo: Failed to establish SSH session", e)
        }
    }

    /** Best-effort close of [client] after a failed connection attempt; problems are logged, never thrown. */
    private fun releaseClientQuietly() {
        if (!::client.isInitialized) {
            return
        }
        try {
            client.close()
        } catch (e: Exception) {
            log.warn("$deviceInfo: Error releasing SSH client after failed connection", e)
        }
    }

    //Needed to unit test connect method interacting with client.start in startClient() below
    private fun setupNewSSHClient() {
        client = SshClient.setUpDefaultClient()
    }

    /** Creates and starts a new SSH client (idle/read timeouts from [idleTimeout]) and then opens a session. */
    private fun startClient() {
        setupNewSSHClient()

        client.properties.putIfAbsent(FactoryManager.IDLE_TIMEOUT, TimeUnit.SECONDS.toMillis(idleTimeout.toLong()))
        // The NIO2 read timeout is kept 15 s above the idle timeout.
        client.properties.putIfAbsent(FactoryManager.NIO2_READ_TIMEOUT, TimeUnit.SECONDS.toMillis(idleTimeout + 15L))
        client.start()

        startSession()
    }

    /** Connects to the device within [connectionTimeout] seconds and continues with authentication. */
    private fun startSession() {
        log.info("$deviceInfo: Starting SSH session")
        val connectFuture = client.connect(deviceInfo.username, deviceInfo.ipAddress, deviceInfo.port)
            .verify(connectionTimeout, TimeUnit.SECONDS)
        session = connectFuture.session
        log.info("$deviceInfo: SSH session created")

        authSession()
    }

    /**
     * Authenticates the session with the device password and opens the NETCONF channel.
     *
     * @throws NetconfException if the session does not reach the AUTHED state
     */
    private fun authSession() {
        session.addPasswordIdentity(deviceInfo.password)
        session.auth().verify(connectionTimeout, TimeUnit.SECONDS)
        // NOTE(review): timeout 0 is passed to waitFor — confirm SSHD semantics (likely "wait indefinitely").
        val event = session.waitFor(ImmutableSet.of(ClientSession.ClientSessionEvent.WAIT_AUTH,
            ClientSession.ClientSessionEvent.CLOSED, ClientSession.ClientSessionEvent.AUTHED), 0)
        if (!event.contains(ClientSession.ClientSessionEvent.AUTHED)) {
            throw NetconfException("$deviceInfo: Failed to authenticate session.")
        }
        log.info("$deviceInfo: SSH session authenticated")

        openChannel()
    }

    /**
     * Opens the `netconf` subsystem channel within [connectionTimeout] seconds and attaches the stream handler.
     *
     * @throws NetconfException if the channel is not opened in time
     */
    private fun openChannel() {
        channel = session.createSubsystemChannel("netconf")
        val channelFuture = channel.open()
        if (channelFuture.await(connectionTimeout, TimeUnit.SECONDS) && channelFuture.isOpened) {
            log.info("$deviceInfo: SSH NETCONF subsystem channel opened")
            setupHandler()
        } else {
            throw NetconfException("$deviceInfo: Failed to open SSH subsystem channel")
        }
    }

    /** Wires a [NetconfDeviceCommunicator] to the channel streams and performs the hello exchange. */
    private fun setupHandler() {
        val sessionListener: NetconfSessionListener = NetconfSessionListenerImpl(this)
        streamHandler = NetconfDeviceCommunicator(channel.invertedOut, channel.invertedIn, deviceInfo,
            sessionListener, replies)

        exchangeHelloMessage()
    }

    /**
     * Sends our `<hello>`, then extracts the session id and the device capabilities from the device's reply.
     *
     * @throws NetconfException if the reply carries no session id
     */
    private fun exchangeHelloMessage() {
        sessionId = "-1"
        val messageId = "-1"
        // Capabilities are per connection. Drop whatever a previous connection announced, so that a reconnect
        // formats its <hello> exactly like the first connect (deviceCapabilities feeds formatRPCRequest)
        // and no stale entries survive.
        deviceCapabilities.clear()

        val serverHelloResponse = syncRpc(NetconfMessageUtils.createHelloString(capabilities), messageId)
        val sessionIDMatcher = NetconfMessageUtils.SESSION_ID_REGEX_PATTERN.matcher(serverHelloResponse)

        if (sessionIDMatcher.find()) {
            sessionId = sessionIDMatcher.group(1)
        } else {
            throw NetconfException("$deviceInfo: Missing sessionId in server hello message: $serverHelloResponse")
        }

        val capabilityMatcher = NetconfMessageUtils.CAPABILITY_REGEX_PATTERN.matcher(serverHelloResponse)
        while (capabilityMatcher.find()) { //TODO: refactor to add unit test easily for device capability accumulation.
            deviceCapabilities.add(capabilityMatcher.group(1))
        }
    }

    internal fun setStreamHandler(streamHandler: NetconfDeviceCommunicator) {
        this.streamHandler = streamHandler
    }

    /**
     * Add an error reply
     * Used by {@link NetconfSessionListenerImpl}
     */
    internal fun addDeviceErrorReply(errReply: String) {
        errorReplies.add(errReply)
    }

    /**
     * Add a reply from the device
     * Used by {@link NetconfSessionListenerImpl}
     * Replies whose message id has no pending future are ignored.
     */
    internal fun addDeviceReply(messageId: String, replyMsg: String) {
        replies[messageId]?.complete(replyMsg)
    }

    /**
     * Closes the session, channel and client, in that order.
     *
     * Every close is attempted even if an earlier one fails, so a failing session close cannot leak the
     * channel or the client. The first failure is rethrown; later ones are attached as suppressed exceptions.
     */
    @Throws(IOException::class)
    private fun close() {
        val closeSteps = listOf<() -> Unit>(
            { session.close() },
            // Closes the socket which should interrupt the streamHandler
            { channel.close() },
            { client.close() }
        )
        var firstFailure: Exception? = null
        for (closeStep in closeSteps) {
            try {
                closeStep()
            } catch (e: Exception) {
                val previous = firstFailure
                if (previous == null) {
                    firstFailure = e
                } else {
                    previous.addSuppressed(e)
                }
            }
        }
        firstFailure?.let { throw it }
    }

    /** Runs [close], logging (not rethrowing) an [IOException]. */
    private fun closeLoggingErrors() {
        try {
            close()
        } catch (ioe: IOException) {
            log.warn("$deviceInfo: Error closing session($sessionId) for host($deviceInfo)", ioe)
        }
    }

    /**
     * Internal function for accessing replies for testing.
     */
    internal fun getReplies() = replies

    /**
     * internal function for accessing errorReplies for testing.
     */
    internal fun getErrorReplies() = errorReplies
    internal fun clearErrorReplies() = errorReplies.clear()
    internal fun clearReplies() = replies.clear()
    internal fun setClient(client: SshClient) { this.client = client }
    internal fun setSession(session: ClientSession) { this.session = session }
    internal fun setChannel(channel: ClientChannel) { this.channel = channel }
}