b3cd46e0746ec13ea39948934318441a0d9dc16f
[msb/apigateway.git] / openresty-ext / src / assembly / resources / openresty / nginx / luaext / loadbalance / policy / consistent_hash.lua
1 --[[
2
3     Copyright (C) 2018 ZTE, Inc. and others. All rights reserved. (ZTE)
4
5     Licensed under the Apache License, Version 2.0 (the "License");
6     you may not use this file except in compliance with the License.
7     You may obtain a copy of the License at
8
9             http://www.apache.org/licenses/LICENSE-2.0
10
11     Unless required by applicable law or agreed to in writing, software
12     distributed under the License is distributed on an "AS IS" BASIS,
13     WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14     See the License for the specific language governing permissions and
15     limitations under the License.
16
17 --]]
18
19 local _M = {}
20 _M._VERSION = '1.0.0'
21
22 local floor      = math.floor
23 local str_byte   = string.byte
24 local tab_sort   = table.sort
25 local tab_insert = table.insert
26
27 local MOD       = 2 ^ 32
28 local REPLICAS  = 20
29 local LUCKY_NUM = 13
30
31
32 local tbl_util  = require('lib.utils.table_util')
33 local tbl_isempty = tbl_util.isempty
34 local tbl_isequal = require('pl.tablex')
35 local peerwatcher = require "core.peerwatcher"
36 local ngx_var = ngx.var
37 local hash_data = {}
38
39 local function hash_string(str)
40     local key = 0
41     for i = 1, #str do
42         key = (key * 31 + str_byte(str, i)) % MOD
43     end
44     return key
45 end
46
47
48 local function init_consistent_hash_state(servers)
49     local weight_sum = 0
50     local weight = 1
51     for _, srv in ipairs(servers) do
52         if srv.weight  and srv.weight ~= 0 then
53             weight = srv.weight
54         end
55         weight_sum = weight_sum + weight
56     end
57
58     local circle, members = {}, 0
59     for index, srv in ipairs(servers) do
60         local key = ("%s:%s"):format(srv.ip, srv.port)
61         local base_hash = hash_string(key)
62         for c = 1, REPLICAS * weight_sum do
63             local hash = (base_hash * c * LUCKY_NUM) % MOD
64             tab_insert(circle, { hash, index })
65         end
66         
67         members = members + 1
68     end
69     tab_sort(circle, function(a, b) return a[1] < b[1] end)
70     return { circle = circle, members = members }
71 end
72
73 local function update_consistent_hash_state(hash_data,servers,svckey)
74     -- compare servers in ctx with servers in cache
75     -- update the hash data if changes occur
76     local serverscache = hash_data[svckey].servers
77     tab_sort(serverscache, function(a, b) return a.ip < b.ip end)
78     tab_sort(servers, function(a, b) return a.ip < b.ip end)
79     if  not tbl_isequal.deepcompare(serverscache, servers, false) then
80         local tmp_chash = init_consistent_hash_state(servers)
81         hash_data[svckey].servers =servers
82         hash_data[svckey].chash = tmp_chash
83     end
84 end
85
86 local function binary_search(circle, key)
87     local size = #circle
88     local st, ed, mid = 1, size
89
90     while st <= ed do
91         mid = floor((st + ed) / 2)
92         if circle[mid][1] < key then
93             st = mid + 1
94         else
95             ed = mid - 1
96         end
97     end
98
99     return st == size + 1 and 1 or st
100 end
101
102
103 function _M.select_backserver(servers,svckey)
104
105     if hash_data[svckey] == nil then
106         local tbl = {}
107         tbl['servers'] = {}
108         tbl['chash'] = {}
109         hash_data[svckey] = tbl
110     end
111
112     if tbl_isempty(hash_data[svckey].servers) then
113         local tmp_chash = init_consistent_hash_state(servers)
114         hash_data[svckey].servers = servers
115         hash_data[svckey].chash = tmp_chash
116     else
117         update_consistent_hash_state(hash_data,servers,svckey)
118     end
119
120     local chash = hash_data[svckey].chash
121     local circle = chash.circle
122     local hash_key = ngx_var.remote_addr
123     local st = binary_search(circle, hash_string(hash_key))
124     local size = #circle
125     local ed = st + size - 1
126     for i = st, ed do
127         local idx = circle[(i - 1) % size + 1][2]
128         if peerwatcher.is_server_ok(svckey,hash_data[svckey].servers[idx]) then
129             return hash_data[svckey].servers[idx]
130         end
131     end
132     return nil, "consistent hash: no servers available"
133 end
134
135 return _M