mirror of
https://github.com/empayre/fleet.git
synced 2024-11-06 17:05:18 +00:00
123 lines
2.7 KiB
Go
123 lines
2.7 KiB
Go
package redis
|
|
|
|
import (
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/fleetdm/fleet/v4/server/fleet"
|
|
"github.com/gomodule/redigo/redis"
|
|
)
|
|
|
|
type ThrottledStore struct {
|
|
Pool fleet.RedisPool
|
|
KeyPrefix string
|
|
}
|
|
|
|
const (
	// compareAndSwapNoKeyError is the error text raised by
	// compareAndSwapWithTTLScript when the key is missing. The Go side looks
	// for this text in the returned error, so it must stay identical to the
	// string passed to redis.error_reply in the script below.
	compareAndSwapNoKeyError = "key does not exist"

	// getWithTimeScript returns the reply of TIME with the value of
	// GET KEYS[1] appended as an extra element.
	getWithTimeScript = `
	local tbl = redis.call('TIME')
	local val = redis.call('GET', KEYS[1])
	table.insert(tbl, val)
	return tbl
	`

	// compareAndSwapWithTTLScript replies with an error when KEYS[1] does not
	// exist and with 0 when its current value differs from ARGV[1]. Otherwise
	// it sets the key to ARGV[2] with an expiry of ARGV[3] seconds and
	// replies with 1.
	compareAndSwapWithTTLScript = `
	local v = redis.call('get', KEYS[1])
	if v == false then
		return redis.error_reply("key does not exist")
	end
	if v ~= ARGV[1] then
		return 0
	end
	redis.call('SET', KEYS[1], ARGV[2], 'EX', ARGV[3])
	return 1
	`
)
|
|
|
|
func (s *ThrottledStore) GetWithTime(key string) (int64, time.Time, error) {
|
|
var t time.Time
|
|
|
|
key = s.KeyPrefix + key
|
|
|
|
conn := s.Pool.Get()
|
|
defer conn.Close()
|
|
if err := BindConn(s.Pool, conn, key); err != nil {
|
|
return 0, t, err
|
|
}
|
|
// must come after BindConn due to redisc restrictions
|
|
conn = ConfigureDoer(s.Pool, conn)
|
|
|
|
script := redis.NewScript(1, getWithTimeScript)
|
|
res, err := redis.Values(script.Do(conn, key))
|
|
if err != nil {
|
|
return 0, t, err
|
|
}
|
|
if len(res) < 3 {
|
|
res = append(res, nil)
|
|
}
|
|
|
|
var secs, us, val int64
|
|
val = -1 // initialize val to -1, will stay untouched if res[2] is nil
|
|
if _, err := redis.Scan(res, &secs, &us, &val); err != nil {
|
|
return 0, t, err
|
|
}
|
|
t = time.Unix(secs, us*int64(time.Microsecond))
|
|
|
|
return val, t, nil
|
|
}
|
|
|
|
func (s *ThrottledStore) SetIfNotExistsWithTTL(key string, value int64, ttl time.Duration) (bool, error) {
|
|
key = s.KeyPrefix + key
|
|
|
|
conn := ConfigureDoer(s.Pool, s.Pool.Get())
|
|
defer conn.Close()
|
|
|
|
ttlSeconds := int(ttl.Seconds())
|
|
// An `EX 0` will fail, make sure that we set expiry for a minimum of one second
|
|
if ttlSeconds < 1 {
|
|
ttlSeconds = 1
|
|
}
|
|
|
|
_, err := redis.String(conn.Do("SET", key, value, "EX", ttlSeconds, "NX"))
|
|
if err != nil {
|
|
if err == redis.ErrNil {
|
|
// not set due to NX condition not met
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
|
|
return true, nil
|
|
}
|
|
|
|
func (s *ThrottledStore) CompareAndSwapWithTTL(key string, old, new int64, ttl time.Duration) (bool, error) {
|
|
key = s.KeyPrefix + key
|
|
|
|
conn := s.Pool.Get()
|
|
defer conn.Close()
|
|
if err := BindConn(s.Pool, conn, key); err != nil {
|
|
return false, err
|
|
}
|
|
// must come after BindConn due to redisc restrictions
|
|
conn = ConfigureDoer(s.Pool, conn)
|
|
|
|
ttlSeconds := int(ttl.Seconds())
|
|
// An `EX 0` will fail, make sure that we set expiry for a minimum of one second
|
|
if ttlSeconds < 1 {
|
|
ttlSeconds = 1
|
|
}
|
|
|
|
script := redis.NewScript(1, compareAndSwapWithTTLScript)
|
|
swapped, err := redis.Bool(script.Do(conn, key, old, new, ttlSeconds))
|
|
if err != nil {
|
|
if strings.Contains(err.Error(), compareAndSwapNoKeyError) {
|
|
return false, nil
|
|
}
|
|
return false, err
|
|
}
|
|
|
|
return swapped, nil
|
|
}
|