12 Commits

Author SHA1 Message Date
Marc Magnin
cf77eaf346 remove subscriptions when a client disconnect 2019-01-18 14:38:51 +01:00
Marc Magnin
7c4d7a0c06 simple fix 2019-01-03 21:23:14 +01:00
joyz
2b56664d85 remove no use 2018-12-27 21:22:32 +08:00
joy.zhou
7547ad3bdc Restruct (#34)
* modify

* remove

* modify

* modify

* remove no use

* add online/offline notification

* modify

* format log

* add reference
2018-12-26 14:51:13 +08:00
joy.zhou
84e7fe2490 context (#28) 2018-05-10 13:13:36 +08:00
zhouyuyan
684584b208 fix write logic 2018-04-28 09:37:37 +08:00
zhouyuyan
56fb4a2d54 fix issue 25 2018-04-28 09:08:28 +08:00
joy.zhou
5ed4728575 Wpool (#23)
* pool

* pool

* wpool
2018-04-04 13:49:52 +08:00
zhouyuyan
c0fea6a5ba modify_message_pool 2018-02-24 13:19:43 +08:00
zhouyuyan
47500910e1 fix broker out painc 2018-02-06 11:01:06 +08:00
joy.zhou
0ff20b6ee2 Update README.md 2018-02-03 13:11:53 +08:00
joy.zhou
7155667f6c Pool (#16)
* add pool

* elastic workerpool

* del buf

* modify usage

* modify readme
2018-02-03 12:42:25 +08:00
17 changed files with 1389 additions and 832 deletions

View File

@@ -28,7 +28,7 @@ Broker Options:
Logging Options: Logging Options:
-d, --debug <bool> Enable debugging output (default false) -d, --debug <bool> Enable debugging output (default false)
-D Debug and trace -D Debug enabled
Cluster Options: Cluster Options:
-r, --router <rurl> Router who maintenance cluster info -r, --router <rurl> Router who maintenance cluster info
@@ -77,8 +77,6 @@ Common Options:
* Supports will messages * Supports will messages
* Queue subscribe
* Websocket Support * Websocket Support
* TLS/SSL Support * TLS/SSL Support
@@ -95,13 +93,6 @@ Common Options:
``` ```
### QUEUE SUBSCRIBE
~~~
| Prefix | Examples |
| ------------- |---------------------------------|
| $queue/ | mosquitto_sub -t $queue/topic |
~~~
### ACL Configure ### ACL Configure
#### The ACL rules define: #### The ACL rules define:
~~~ ~~~
@@ -154,6 +145,14 @@ Client -> | Rule1 | --nomatch--> | Rule2 | --nomatch--> | Rule3 | -->
allow | deny allow | deny allow | deny allow | deny allow | deny allow | deny
~~~ ~~~
### Online/Offline Notification
```bash
topic:
$SYS/broker/connection/clients/<clientID>
payload:
{"clientID":"client001","online":true/false,"timestamp":"2018-10-25T09:32:32Z"}
```
## Performance ## Performance
* High throughput * High throughput
@@ -166,3 +165,8 @@ Client -> | Rule1 | --nomatch--> | Rule2 | --nomatch--> | Rule3 | -->
## License ## License
* Apache License Version 2.0 * Apache License Version 2.0
## Reference
* Surgermq.(https://github.com/surgemq/surgemq)

View File

@@ -3,13 +3,10 @@
package broker package broker
import ( import (
"strings"
"github.com/fhmq/hmq/lib/acl" "github.com/fhmq/hmq/lib/acl"
"go.uber.org/zap"
"github.com/fsnotify/fsnotify" "github.com/fsnotify/fsnotify"
"go.uber.org/zap"
"strings"
) )
const ( const (
@@ -44,10 +41,10 @@ func (b *Broker) handleFsEvent(event fsnotify.Event) error {
case b.config.AclConf: case b.config.AclConf:
if event.Op&fsnotify.Write == fsnotify.Write || if event.Op&fsnotify.Write == fsnotify.Write ||
event.Op&fsnotify.Create == fsnotify.Create { event.Op&fsnotify.Create == fsnotify.Create {
brokerLog.Info("text:handling acl config change event:", zap.String("filename", event.Name)) log.Info("text:handling acl config change event:", zap.String("filename", event.Name))
aclconfig, err := acl.AclConfigLoad(event.Name) aclconfig, err := acl.AclConfigLoad(event.Name)
if err != nil { if err != nil {
brokerLog.Error("aclconfig change failed, load acl conf error: ", zap.Error(err)) log.Error("aclconfig change failed, load acl conf error: ", zap.Error(err))
return err return err
} }
b.AclConfig = aclconfig b.AclConfig = aclconfig
@@ -60,24 +57,24 @@ func (b *Broker) StartAclWatcher() {
go func() { go func() {
wch, e := fsnotify.NewWatcher() wch, e := fsnotify.NewWatcher()
if e != nil { if e != nil {
brokerLog.Error("start monitor acl config file error,", zap.Error(e)) log.Error("start monitor acl config file error,", zap.Error(e))
return return
} }
defer wch.Close() defer wch.Close()
for _, i := range watchList { for _, i := range watchList {
if err := wch.Add(i); err != nil { if err := wch.Add(i); err != nil {
brokerLog.Error("start monitor acl config file error,", zap.Error(err)) log.Error("start monitor acl config file error,", zap.Error(err))
return return
} }
} }
brokerLog.Info("watching acl config file change...") log.Info("watching acl config file change...")
for { for {
select { select {
case evt := <-wch.Events: case evt := <-wch.Events:
b.handleFsEvent(evt) b.handleFsEvent(evt)
case err := <-wch.Errors: case err := <-wch.Errors:
brokerLog.Error("error:", zap.Error(err)) log.Error("error:", zap.Error(err))
} }
} }
}() }()

View File

@@ -4,6 +4,7 @@ package broker
import ( import (
"crypto/tls" "crypto/tls"
"fmt"
"net" "net"
"net/http" "net/http"
"runtime/debug" "runtime/debug"
@@ -11,18 +12,19 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/fhmq/hmq/lib/acl"
"github.com/fhmq/hmq/pool"
"github.com/eclipse/paho.mqtt.golang/packets" "github.com/eclipse/paho.mqtt.golang/packets"
"github.com/fhmq/hmq/lib/acl"
"github.com/fhmq/hmq/lib/sessions"
"github.com/fhmq/hmq/lib/topics"
"github.com/fhmq/hmq/pool"
"github.com/shirou/gopsutil/mem" "github.com/shirou/gopsutil/mem"
"go.uber.org/zap" "go.uber.org/zap"
"golang.org/x/net/websocket" "golang.org/x/net/websocket"
) )
var ( const (
brokerLog *zap.Logger MessagePoolNum = 1024
MessagePoolMessageNum = 1024
) )
type Message struct { type Message struct {
@@ -43,10 +45,19 @@ type Broker struct {
remotes sync.Map remotes sync.Map
nodes map[string]interface{} nodes map[string]interface{}
clusterPool chan *Message clusterPool chan *Message
messagePool chan *Message
sl *Sublist
rl *RetainList
queues map[string]int queues map[string]int
topicsMgr *topics.Manager
sessionMgr *sessions.Manager
// messagePool []chan *Message
}
func newMessagePool() []chan *Message {
pool := make([]chan *Message, 0)
for i := 0; i < MessagePoolNum; i++ {
ch := make(chan *Message, MessagePoolMessageNum)
pool = append(pool, ch)
}
return pool
} }
func NewBroker(config *Config) (*Broker, error) { func NewBroker(config *Config) (*Broker, error) {
@@ -54,17 +65,28 @@ func NewBroker(config *Config) (*Broker, error) {
id: GenUniqueId(), id: GenUniqueId(),
config: config, config: config,
wpool: pool.New(config.Worker), wpool: pool.New(config.Worker),
sl: NewSublist(),
rl: NewRetainList(),
nodes: make(map[string]interface{}), nodes: make(map[string]interface{}),
queues: make(map[string]int), queues: make(map[string]int),
clusterPool: make(chan *Message), clusterPool: make(chan *Message),
messagePool: make(chan *Message),
} }
var err error
b.topicsMgr, err = topics.NewManager("mem")
if err != nil {
log.Error("new topic manager error", zap.Error(err))
return nil, err
}
b.sessionMgr, err = sessions.NewManager("mem")
if err != nil {
log.Error("new session manager error", zap.Error(err))
return nil, err
}
if b.config.TlsPort != "" { if b.config.TlsPort != "" {
tlsconfig, err := NewTLSConfig(b.config.TlsInfo) tlsconfig, err := NewTLSConfig(b.config.TlsInfo)
if err != nil { if err != nil {
brokerLog.Error("new tlsConfig error", zap.Error(err)) log.Error("new tlsConfig error", zap.Error(err))
return nil, err return nil, err
} }
b.tlsConfig = tlsconfig b.tlsConfig = tlsconfig
@@ -72,7 +94,7 @@ func NewBroker(config *Config) (*Broker, error) {
if b.config.Acl { if b.config.Acl {
aclconfig, err := acl.AclConfigLoad(b.config.AclConf) aclconfig, err := acl.AclConfigLoad(b.config.AclConf)
if err != nil { if err != nil {
brokerLog.Error("Load acl conf error", zap.Error(err)) log.Error("Load acl conf error", zap.Error(err))
return nil, err return nil, err
} }
b.AclConfig = aclconfig b.AclConfig = aclconfig
@@ -81,13 +103,14 @@ func NewBroker(config *Config) (*Broker, error) {
return b, nil return b, nil
} }
func (b *Broker) StartDispatcher() { func (b *Broker) SubmitWork(msg *Message) {
for { if b.wpool == nil {
msg, ok := <-b.messagePool b.wpool = pool.New(b.config.Worker)
if !ok { }
brokerLog.Error("read message from client channel error")
return if msg.client.typ == CLUSTER {
} b.clusterPool <- msg
} else {
b.wpool.Submit(func() { b.wpool.Submit(func() {
ProcessMessage(msg) ProcessMessage(msg)
}) })
@@ -97,10 +120,9 @@ func (b *Broker) StartDispatcher() {
func (b *Broker) Start() { func (b *Broker) Start() {
if b == nil { if b == nil {
brokerLog.Error("broker is null") log.Error("broker is null")
return return
} }
go b.StartDispatcher()
//listen clinet over tcp //listen clinet over tcp
if b.config.Port != "" { if b.config.Port != "" {
@@ -149,7 +171,7 @@ func StateMonitor() {
func (b *Broker) StartWebsocketListening() { func (b *Broker) StartWebsocketListening() {
path := b.config.WsPath path := b.config.WsPath
hp := ":" + b.config.WsPort hp := ":" + b.config.WsPort
brokerLog.Info("Start Websocket Listener on:", zap.String("hp", hp), zap.String("path", path)) log.Info("Start Websocket Listener on:", zap.String("hp", hp), zap.String("path", path))
http.Handle(path, websocket.Handler(b.wsHandler)) http.Handle(path, websocket.Handler(b.wsHandler))
var err error var err error
if b.config.WsTLS { if b.config.WsTLS {
@@ -158,7 +180,7 @@ func (b *Broker) StartWebsocketListening() {
err = http.ListenAndServe(hp, nil) err = http.ListenAndServe(hp, nil)
} }
if err != nil { if err != nil {
brokerLog.Error("ListenAndServe:" + err.Error()) log.Error("ListenAndServe:" + err.Error())
return return
} }
} }
@@ -167,7 +189,7 @@ func (b *Broker) wsHandler(ws *websocket.Conn) {
// io.Copy(ws, ws) // io.Copy(ws, ws)
atomic.AddUint64(&b.cid, 1) atomic.AddUint64(&b.cid, 1)
ws.PayloadType = websocket.BinaryFrame ws.PayloadType = websocket.BinaryFrame
b.handleConnection(CLIENT, ws, b.cid) b.handleConnection(CLIENT, ws)
} }
func (b *Broker) StartClientListening(Tls bool) { func (b *Broker) StartClientListening(Tls bool) {
@@ -177,14 +199,14 @@ func (b *Broker) StartClientListening(Tls bool) {
if Tls { if Tls {
hp = b.config.TlsHost + ":" + b.config.TlsPort hp = b.config.TlsHost + ":" + b.config.TlsPort
l, err = tls.Listen("tcp", hp, b.tlsConfig) l, err = tls.Listen("tcp", hp, b.tlsConfig)
brokerLog.Info("Start TLS Listening client on ", zap.String("hp", hp)) log.Info("Start TLS Listening client on ", zap.String("hp", hp))
} else { } else {
hp := b.config.Host + ":" + b.config.Port hp := b.config.Host + ":" + b.config.Port
l, err = net.Listen("tcp", hp) l, err = net.Listen("tcp", hp)
brokerLog.Info("Start Listening client on ", zap.String("hp", hp)) log.Info("Start Listening client on ", zap.String("hp", hp))
} }
if err != nil { if err != nil {
brokerLog.Error("Error listening on ", zap.Error(err)) log.Error("Error listening on ", zap.Error(err))
return return
} }
tmpDelay := 10 * ACCEPT_MIN_SLEEP tmpDelay := 10 * ACCEPT_MIN_SLEEP
@@ -192,7 +214,7 @@ func (b *Broker) StartClientListening(Tls bool) {
conn, err := l.Accept() conn, err := l.Accept()
if err != nil { if err != nil {
if ne, ok := err.(net.Error); ok && ne.Temporary() { if ne, ok := err.(net.Error); ok && ne.Temporary() {
brokerLog.Error("Temporary Client Accept Error(%v), sleeping %dms", log.Error("Temporary Client Accept Error(%v), sleeping %dms",
zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond)) zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond))
time.Sleep(tmpDelay) time.Sleep(tmpDelay)
tmpDelay *= 2 tmpDelay *= 2
@@ -200,13 +222,13 @@ func (b *Broker) StartClientListening(Tls bool) {
tmpDelay = ACCEPT_MAX_SLEEP tmpDelay = ACCEPT_MAX_SLEEP
} }
} else { } else {
brokerLog.Error("Accept error: %v", zap.Error(err)) log.Error("Accept error: %v", zap.Error(err))
} }
continue continue
} }
tmpDelay = ACCEPT_MIN_SLEEP tmpDelay = ACCEPT_MIN_SLEEP
atomic.AddUint64(&b.cid, 1) atomic.AddUint64(&b.cid, 1)
go b.handleConnection(CLIENT, conn, b.cid) go b.handleConnection(CLIENT, conn)
} }
} }
@@ -219,7 +241,7 @@ func (b *Broker) Handshake(conn net.Conn) bool {
// Force handshake // Force handshake
if err := nc.Handshake(); err != nil { if err := nc.Handshake(); err != nil {
brokerLog.Error("TLS handshake error, ", zap.Error(err)) log.Error("TLS handshake error, ", zap.Error(err))
return false return false
} }
nc.SetReadDeadline(time.Time{}) nc.SetReadDeadline(time.Time{})
@@ -235,28 +257,27 @@ func TlsTimeout(conn *tls.Conn) {
} }
cs := nc.ConnectionState() cs := nc.ConnectionState()
if !cs.HandshakeComplete { if !cs.HandshakeComplete {
brokerLog.Error("TLS handshake timeout") log.Error("TLS handshake timeout")
nc.Close() nc.Close()
} }
} }
func (b *Broker) StartClusterListening() { func (b *Broker) StartClusterListening() {
var hp string = b.config.Cluster.Host + ":" + b.config.Cluster.Port var hp string = b.config.Cluster.Host + ":" + b.config.Cluster.Port
brokerLog.Info("Start Listening cluster on ", zap.String("hp", hp)) log.Info("Start Listening cluster on ", zap.String("hp", hp))
l, e := net.Listen("tcp", hp) l, e := net.Listen("tcp", hp)
if e != nil { if e != nil {
brokerLog.Error("Error listening on ", zap.Error(e)) log.Error("Error listening on ", zap.Error(e))
return return
} }
var idx uint64 = 0
tmpDelay := 10 * ACCEPT_MIN_SLEEP tmpDelay := 10 * ACCEPT_MIN_SLEEP
for { for {
conn, err := l.Accept() conn, err := l.Accept()
if err != nil { if err != nil {
if ne, ok := err.(net.Error); ok && ne.Temporary() { if ne, ok := err.(net.Error); ok && ne.Temporary() {
brokerLog.Error("Temporary Client Accept Error(%v), sleeping %dms", log.Error("Temporary Client Accept Error(%v), sleeping %dms",
zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond)) zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond))
time.Sleep(tmpDelay) time.Sleep(tmpDelay)
tmpDelay *= 2 tmpDelay *= 2
@@ -264,30 +285,30 @@ func (b *Broker) StartClusterListening() {
tmpDelay = ACCEPT_MAX_SLEEP tmpDelay = ACCEPT_MAX_SLEEP
} }
} else { } else {
brokerLog.Error("Accept error: %v", zap.Error(err)) log.Error("Accept error: %v", zap.Error(err))
} }
continue continue
} }
tmpDelay = ACCEPT_MIN_SLEEP tmpDelay = ACCEPT_MIN_SLEEP
go b.handleConnection(ROUTER, conn, idx) go b.handleConnection(ROUTER, conn)
} }
} }
func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) { func (b *Broker) handleConnection(typ int, conn net.Conn) {
//process connect packet //process connect packet
packet, err := packets.ReadPacket(conn) packet, err := packets.ReadPacket(conn)
if err != nil { if err != nil {
brokerLog.Error("read connect packet error: ", zap.Error(err)) log.Error("read connect packet error: ", zap.Error(err))
return return
} }
if packet == nil { if packet == nil {
brokerLog.Error("received nil packet") log.Error("received nil packet")
return return
} }
msg, ok := packet.(*packets.ConnectPacket) msg, ok := packet.(*packets.ConnectPacket)
if !ok { if !ok {
brokerLog.Error("received msg that was not Connect") log.Error("received msg that was not Connect")
return return
} }
connack := packets.NewControlPacket(packets.Connack).(*packets.ConnackPacket) connack := packets.NewControlPacket(packets.Connack).(*packets.ConnackPacket)
@@ -295,7 +316,7 @@ func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
connack.SessionPresent = msg.CleanSession connack.SessionPresent = msg.CleanSession
err = connack.Write(conn) err = connack.Write(conn)
if err != nil { if err != nil {
brokerLog.Error("send connack error, ", zap.Error(err), zap.String("clientID", msg.ClientIdentifier)) log.Error("send connack error, ", zap.Error(err), zap.String("clientID", msg.ClientIdentifier))
return return
} }
@@ -326,6 +347,12 @@ func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
c.init() c.init()
err = b.getSession(c, msg, connack)
if err != nil {
log.Error("get session error: ", zap.String("clientID", c.info.clientID))
return
}
cid := c.info.clientID cid := c.info.clientID
var exist bool var exist bool
@@ -335,17 +362,19 @@ func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
case CLIENT: case CLIENT:
old, exist = b.clients.Load(cid) old, exist = b.clients.Load(cid)
if exist { if exist {
brokerLog.Warn("client exist, close old...", zap.String("clientID", c.info.clientID)) log.Warn("client exist, close old...", zap.String("clientID", c.info.clientID))
ol, ok := old.(*client) ol, ok := old.(*client)
if ok { if ok {
ol.Close() ol.Close()
} }
} }
b.clients.Store(cid, c) b.clients.Store(cid, c)
b.OnlineOfflineNotification(cid, true)
case ROUTER: case ROUTER:
old, exist = b.routes.Load(cid) old, exist = b.routes.Load(cid)
if exist { if exist {
brokerLog.Warn("router exist, close old...") log.Warn("router exist, close old...")
ol, ok := old.(*client) ol, ok := old.(*client)
if ok { if ok {
ol.Close() ol.Close()
@@ -354,7 +383,9 @@ func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
b.routes.Store(cid, c) b.routes.Store(cid, c)
} }
c.readLoop(b.messagePool) // mpool := b.messagePool[fnv1a.HashString64(cid)%MessagePoolNum]
c.readLoop()
} }
func (b *Broker) ConnectToDiscovery() { func (b *Broker) ConnectToDiscovery() {
@@ -364,8 +395,8 @@ func (b *Broker) ConnectToDiscovery() {
for { for {
conn, err = net.Dial("tcp", b.config.Router) conn, err = net.Dial("tcp", b.config.Router)
if err != nil { if err != nil {
brokerLog.Error("Error trying to connect to route: ", zap.Error(err)) log.Error("Error trying to connect to route: ", zap.Error(err))
brokerLog.Debug("Connect to route timeout ,retry...") log.Debug("Connect to route timeout ,retry...")
if 0 == tempDelay { if 0 == tempDelay {
tempDelay = 1 * time.Second tempDelay = 1 * time.Second
@@ -381,7 +412,7 @@ func (b *Broker) ConnectToDiscovery() {
} }
break break
} }
brokerLog.Debug("connect to router success :", zap.String("Router", b.config.Router)) log.Debug("connect to router success :", zap.String("Router", b.config.Router))
cid := b.id cid := b.id
info := info{ info := info{
@@ -401,7 +432,7 @@ func (b *Broker) ConnectToDiscovery() {
c.SendConnect() c.SendConnect()
c.SendInfo() c.SendInfo()
go c.readLoop(b.clusterPool) go c.readLoop()
go c.StartPing() go c.StartPing()
} }
@@ -409,7 +440,7 @@ func (b *Broker) processClusterInfo() {
for { for {
msg, ok := <-b.clusterPool msg, ok := <-b.clusterPool
if !ok { if !ok {
brokerLog.Error("read message from cluster channel error") log.Error("read message from cluster channel error")
return return
} }
ProcessMessage(msg) ProcessMessage(msg)
@@ -431,13 +462,13 @@ func (b *Broker) connectRouter(id, addr string) {
conn, err = net.Dial("tcp", addr) conn, err = net.Dial("tcp", addr)
if err != nil { if err != nil {
brokerLog.Error("Error trying to connect to route: ", zap.Error(err)) log.Error("Error trying to connect to route: ", zap.Error(err))
if retryTimes > 50 { if retryTimes > 50 {
return return
} }
brokerLog.Debug("Connect to route timeout ,retry...") log.Debug("Connect to route timeout ,retry...")
if 0 == timeDelay { if 0 == timeDelay {
timeDelay = 1 * time.Second timeDelay = 1 * time.Second
@@ -477,7 +508,8 @@ func (b *Broker) connectRouter(id, addr string) {
c.SendConnect() c.SendConnect()
go c.readLoop(b.messagePool) // mpool := b.messagePool[fnv1a.HashString64(cid)%MessagePoolNum]
go c.readLoop()
go c.StartPing() go c.StartPing()
} }
@@ -525,9 +557,9 @@ func (b *Broker) SendLocalSubsToRouter(c *client) {
b.clients.Range(func(key, value interface{}) bool { b.clients.Range(func(key, value interface{}) bool {
client, ok := value.(*client) client, ok := value.(*client)
if ok { if ok {
subs := client.subs subs := client.subMap
for _, sub := range subs { for _, sub := range subs {
subInfo.Topics = append(subInfo.Topics, string(sub.topic)) subInfo.Topics = append(subInfo.Topics, sub.topic)
subInfo.Qoss = append(subInfo.Qoss, sub.qos) subInfo.Qoss = append(subInfo.Qoss, sub.qos)
} }
} }
@@ -536,7 +568,7 @@ func (b *Broker) SendLocalSubsToRouter(c *client) {
if len(subInfo.Topics) > 0 { if len(subInfo.Topics) > 0 {
err := c.WriterPacket(subInfo) err := c.WriterPacket(subInfo)
if err != nil { if err != nil {
brokerLog.Error("Send localsubs To Router error :", zap.Error(err)) log.Error("Send localsubs To Router error :", zap.Error(err))
} }
} }
} }
@@ -553,7 +585,7 @@ func (b *Broker) BroadcastInfoMessage(remoteID string, msg *packets.PublishPacke
return true return true
}) })
// brokerLog.Info("BroadcastInfoMessage success ") // log.Info("BroadcastInfoMessage success ")
} }
func (b *Broker) BroadcastSubOrUnsubMessage(packet packets.ControlPacket) { func (b *Broker) BroadcastSubOrUnsubMessage(packet packets.ControlPacket) {
@@ -565,7 +597,7 @@ func (b *Broker) BroadcastSubOrUnsubMessage(packet packets.ControlPacket) {
} }
return true return true
}) })
// brokerLog.Info("BroadcastSubscribeMessage remotes: ", s.remotes) // log.Info("BroadcastSubscribeMessage remotes: ", s.remotes)
} }
func (b *Broker) removeClient(c *client) { func (b *Broker) removeClient(c *client) {
@@ -579,21 +611,26 @@ func (b *Broker) removeClient(c *client) {
case REMOTE: case REMOTE:
b.remotes.Delete(clientId) b.remotes.Delete(clientId)
} }
// brokerLog.Info("delete client ,", clientId) // log.Info("delete client ,", clientId)
} }
func (b *Broker) PublishMessage(packet *packets.PublishPacket) { func (b *Broker) PublishMessage(packet *packets.PublishPacket) {
topic := packet.TopicName var subs []interface{}
r := b.sl.Match(topic) var qoss []byte
if len(r.psubs) == 0 { b.mu.Lock()
err := b.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &subs, &qoss)
b.mu.Unlock()
if err != nil {
log.Error("search sub client error, ", zap.Error(err))
return return
} }
for _, sub := range r.psubs { for _, sub := range subs {
if sub != nil { s, ok := sub.(*subscription)
err := sub.client.WriterPacket(packet) if ok {
err := s.client.WriterPacket(packet)
if err != nil { if err != nil {
brokerLog.Error("process message for psub error, ", zap.Error(err)) log.Error("write message error, ", zap.Error(err))
} }
} }
} }
@@ -610,3 +647,12 @@ func (b *Broker) BroadcastUnSubscribe(subs map[string]*subscription) {
b.BroadcastSubOrUnsubMessage(unsub) b.BroadcastSubOrUnsubMessage(unsub)
} }
} }
func (b *Broker) OnlineOfflineNotification(clientID string, online bool) {
packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
packet.TopicName = "$SYS/broker/connection/clients/" + clientID
packet.Qos = 0
packet.Payload = []byte(fmt.Sprintf(`{"clientID":"%s","online":%v,"timestamp":"%s"}`, clientID, online, time.Now().UTC().Format(time.RFC3339)))
b.PublishMessage(packet)
}

View File

@@ -3,6 +3,8 @@
package broker package broker
import ( import (
"context"
"errors"
"net" "net"
"reflect" "reflect"
"strings" "strings"
@@ -10,6 +12,8 @@ import (
"time" "time"
"github.com/eclipse/paho.mqtt.golang/packets" "github.com/eclipse/paho.mqtt.golang/packets"
"github.com/fhmq/hmq/lib/sessions"
"github.com/fhmq/hmq/lib/topics"
"go.uber.org/zap" "go.uber.org/zap"
) )
@@ -30,22 +34,21 @@ const (
) )
type client struct { type client struct {
typ int typ int
mu sync.Mutex mu sync.Mutex
broker *Broker broker *Broker
conn net.Conn conn net.Conn
info info info info
route route route route
status int status int
closed chan int ctx context.Context
smu sync.RWMutex cancelFunc context.CancelFunc
subs map[string]*subscription session *sessions.Session
rsubs map[string]*subInfo subMap map[string]*subscription
} topicsMgr *topics.Manager
subs []interface{}
type subInfo struct { qoss []byte
sub *subscription rmsgs []*packets.PublishPacket
num int
} }
type subscription struct { type subscription struct {
@@ -75,71 +78,59 @@ var (
) )
func (c *client) init() { func (c *client) init() {
c.smu.Lock()
defer c.smu.Unlock()
c.status = Connected c.status = Connected
c.closed = make(chan int, 1)
c.rsubs = make(map[string]*subInfo)
c.subs = make(map[string]*subscription, 10)
c.info.localIP = strings.Split(c.conn.LocalAddr().String(), ":")[0] c.info.localIP = strings.Split(c.conn.LocalAddr().String(), ":")[0]
c.info.remoteIP = strings.Split(c.conn.RemoteAddr().String(), ":")[0] c.info.remoteIP = strings.Split(c.conn.RemoteAddr().String(), ":")[0]
c.ctx, c.cancelFunc = context.WithCancel(context.Background())
c.subMap = make(map[string]*subscription)
c.topicsMgr = c.broker.topicsMgr
} }
func (c *client) keepAlive(ch chan int, mpool chan *Message) { func (c *client) readLoop() {
defer close(ch)
keepalive := time.Duration(c.info.keepalive*3/2) * time.Second
timer := time.NewTimer(keepalive)
for {
select {
case <-ch:
timer.Reset(keepalive)
case <-timer.C:
if c.typ == REMOTE || c.typ == CLUSTER {
timer.Reset(keepalive)
continue
}
brokerLog.Error("Client exceeded timeout, disconnecting. ", zap.String("ClientID", c.info.clientID), zap.Uint16("keepalive", c.info.keepalive))
msg := &Message{client: c, packet: DisconnectdPacket}
mpool <- msg
timer.Stop()
return
case _, ok := <-c.closed:
if !ok {
return
}
}
}
}
func (c *client) readLoop(mpool chan *Message) {
nc := c.conn nc := c.conn
if nc == nil || mpool == nil { b := c.broker
if nc == nil || b == nil {
return return
} }
ch := make(chan int, 1000) keepAlive := time.Second * time.Duration(c.info.keepalive)
go c.keepAlive(ch, mpool) timeOut := keepAlive + (keepAlive / 2)
for { for {
packet, err := packets.ReadPacket(nc) select {
if err != nil { case <-c.ctx.Done():
brokerLog.Error("read packet error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) return
break default:
} //add read timeout
// keepalive channel if err := nc.SetReadDeadline(time.Now().Add(timeOut)); err != nil {
ch <- 1 log.Error("set read timeout error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
return
}
msg := &Message{ packet, err := packets.ReadPacket(nc)
client: c, if err != nil {
packet: packet, log.Error("read packet error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
msg := &Message{client: c, packet: DisconnectdPacket}
b.SubmitWork(msg)
// remove subscriptions related to that client
for topic, sub := range c.subMap {
t := []byte(topic)
c.topicsMgr.Unsubscribe(t, sub)
c.session.RemoveTopic(topic)
delete(c.subMap, topic)
}
return
}
msg := &Message{
client: c,
packet: packet,
}
b.SubmitWork(msg)
} }
mpool <- msg
} }
msg := &Message{client: c, packet: DisconnectdPacket}
mpool <- msg
} }
func ProcessMessage(msg *Message) { func ProcessMessage(msg *Message) {
@@ -148,8 +139,7 @@ func ProcessMessage(msg *Message) {
if ca == nil { if ca == nil {
return return
} }
log.Debug("Recv message:", zap.String("message type", reflect.TypeOf(msg.packet).String()[9:]), zap.String("ClientID", c.info.clientID))
brokerLog.Debug("Recv message:", zap.String("message type", reflect.TypeOf(msg.packet).String()[9:]), zap.String("ClientID", c.info.clientID))
switch ca.(type) { switch ca.(type) {
case *packets.ConnackPacket: case *packets.ConnackPacket:
case *packets.ConnectPacket: case *packets.ConnectPacket:
@@ -174,7 +164,7 @@ func ProcessMessage(msg *Message) {
case *packets.DisconnectPacket: case *packets.DisconnectPacket:
c.Close() c.Close()
default: default:
brokerLog.Info("Recv Unknow message.......", zap.String("ClientID", c.info.clientID)) log.Info("Recv Unknow message.......", zap.String("ClientID", c.info.clientID))
} }
} }
@@ -190,7 +180,7 @@ func (c *client) ProcessPublish(packet *packets.PublishPacket) {
} }
if !c.CheckTopicAuth(PUB, topic) { if !c.CheckTopicAuth(PUB, topic) {
brokerLog.Error("Pub Topics Auth failed, ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID)) log.Error("Pub Topics Auth failed, ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID))
return return
} }
@@ -201,24 +191,16 @@ func (c *client) ProcessPublish(packet *packets.PublishPacket) {
puback := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket) puback := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
puback.MessageID = packet.MessageID puback.MessageID = packet.MessageID
if err := c.WriterPacket(puback); err != nil { if err := c.WriterPacket(puback); err != nil {
brokerLog.Error("send puback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) log.Error("send puback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
return return
} }
c.ProcessPublishMessage(packet) c.ProcessPublishMessage(packet)
case QosExactlyOnce: case QosExactlyOnce:
return return
default: default:
brokerLog.Error("publish with unknown qos", zap.String("ClientID", c.info.clientID)) log.Error("publish with unknown qos", zap.String("ClientID", c.info.clientID))
return return
} }
if packet.Retain {
if b := c.broker; b != nil {
err := b.rl.Insert(topic, packet)
if err != nil {
brokerLog.Error("Insert Retain Message error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
}
}
}
} }
@@ -232,81 +214,42 @@ func (c *client) ProcessPublishMessage(packet *packets.PublishPacket) {
return return
} }
typ := c.typ typ := c.typ
topic := packet.TopicName
r := b.sl.Match(topic) if packet.Retain {
// brokerLog.Info("psubs num: ", len(r.psubs)) if err := c.topicsMgr.Retain(packet); err != nil {
if len(r.qsubs) == 0 && len(r.psubs) == 0 { log.Error("Error retaining message: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
}
}
c.mu.Lock()
err := c.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &c.subs, &c.qoss)
c.mu.Unlock()
if err != nil {
log.Error("Error retrieving subscribers list: ", zap.String("ClientID", c.info.clientID))
return return
} }
for _, sub := range r.psubs { // log.Info("psubs num: ", len(r.psubs))
if sub.client.typ == ROUTER { if len(c.subs) == 0 {
if typ != CLIENT { return
continue
}
}
if sub != nil {
err := sub.client.WriterPacket(packet)
if err != nil {
brokerLog.Error("process message for psub error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
}
}
} }
pre := -1 for _, sub := range c.subs {
now := -1 s, ok := sub.(*subscription)
t := "$queue/" + topic if ok {
cnt, exist := b.queues[t] if s.client.typ == ROUTER {
if exist {
// brokerLog.Info("queue index : ", cnt)
for _, sub := range r.qsubs {
if sub.client.typ == ROUTER {
if typ != CLIENT { if typ != CLIENT {
continue continue
} }
} }
if c.typ == CLIENT { err := s.client.WriterPacket(packet)
now = now + 1 if err != nil {
} else { log.Error("process message for psub error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
now = now + sub.client.rsubs[t].num
} }
if cnt > pre && cnt <= now {
if sub != nil {
err := sub.client.WriterPacket(packet)
if err != nil {
brokerLog.Error("send publish error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
}
}
break
}
pre = now
} }
} }
length := getQueueSubscribeNum(r.qsubs)
if length > 0 {
b.queues[t] = (b.queues[t] + 1) % length
}
}
func getQueueSubscribeNum(qsubs []*subscription) int {
topic := "$queue/"
if len(qsubs) < 1 {
return 0
} else {
topic = topic + qsubs[0].topic
}
num := 0
for _, sub := range qsubs {
if sub.client.typ == CLIENT {
num = num + 1
} else {
num = num + sub.client.rsubs[topic].num
}
}
return num
} }
func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) { func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) {
@@ -329,64 +272,34 @@ func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) {
t := topic t := topic
//check topic auth for client //check topic auth for client
if !c.CheckTopicAuth(SUB, topic) { if !c.CheckTopicAuth(SUB, topic) {
brokerLog.Error("Sub topic Auth failed: ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID)) log.Error("Sub topic Auth failed: ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID))
retcodes = append(retcodes, QosFailure) retcodes = append(retcodes, QosFailure)
continue continue
} }
queue := strings.HasPrefix(topic, "$queue/")
if queue {
if len(t) > 7 {
t = t[7:]
if _, exists := b.queues[topic]; !exists {
b.queues[topic] = 0
}
} else {
retcodes = append(retcodes, QosFailure)
continue
}
}
sub := &subscription{ sub := &subscription{
topic: t, topic: t,
qos: qoss[i], qos: qoss[i],
client: c, client: c,
queue: queue,
} }
switch c.typ {
case CLIENT:
if _, exist := c.subs[topic]; !exist {
c.subs[topic] = sub
} else { rqos, err := c.topicsMgr.Subscribe([]byte(topic), qoss[i], sub)
//if exist ,check whether qos change
c.subs[topic].qos = qoss[i]
retcodes = append(retcodes, qoss[i])
continue
}
case ROUTER:
if subinfo, exist := c.rsubs[topic]; !exist {
sinfo := &subInfo{sub: sub, num: 1}
c.rsubs[topic] = sinfo
} else {
subinfo.num = subinfo.num + 1
retcodes = append(retcodes, qoss[i])
continue
}
}
err := b.sl.Insert(sub)
if err != nil { if err != nil {
brokerLog.Error("Insert subscription error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) return
retcodes = append(retcodes, QosFailure)
} else {
retcodes = append(retcodes, qoss[i])
} }
c.subMap[topic] = sub
c.session.AddTopic(topic, qoss[i])
retcodes = append(retcodes, rqos)
c.topicsMgr.Retained([]byte(topic), &c.rmsgs)
} }
suback.ReturnCodes = retcodes suback.ReturnCodes = retcodes
err := c.WriterPacket(suback) err := c.WriterPacket(suback)
if err != nil { if err != nil {
brokerLog.Error("send suback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) log.Error("send suback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
return return
} }
//broadcast subscribe message //broadcast subscribe message
@@ -395,13 +308,11 @@ func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) {
} }
//process retain message //process retain message
for _, t := range topics { for _, rm := range c.rmsgs {
packets := b.rl.Match(t) if err := c.WriterPacket(rm); err != nil {
for _, packet := range packets { log.Error("Error publishing retained message:", zap.Any("err", err), zap.String("ClientID", c.info.clientID))
brokerLog.Info("process retain message: ", zap.Any("packet", packet), zap.String("ClientID", c.info.clientID)) } else {
if packet != nil { log.Info("process retain message: ", zap.Any("packet", packet), zap.String("ClientID", c.info.clientID))
c.WriterPacket(packet)
}
} }
} }
} }
@@ -414,30 +325,16 @@ func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) {
if b == nil { if b == nil {
return return
} }
typ := c.typ
topics := packet.Topics topics := packet.Topics
for _, t := range topics { for _, topic := range topics {
t := []byte(topic)
switch typ { sub, exist := c.subMap[topic]
case CLIENT: if exist {
sub, ok := c.subs[t] c.topicsMgr.Unsubscribe(t, sub)
if ok { c.session.RemoveTopic(topic)
c.unsubscribe(sub) delete(c.subMap, topic)
}
case ROUTER:
subinfo, ok := c.rsubs[t]
if ok {
subinfo.num = subinfo.num - 1
if subinfo.num < 1 {
delete(c.rsubs, t)
c.unsubscribe(subinfo.sub)
} else {
c.rsubs[t] = subinfo
}
}
} }
} }
unsuback := packets.NewControlPacket(packets.Unsuback).(*packets.UnsubackPacket) unsuback := packets.NewControlPacket(packets.Unsuback).(*packets.UnsubackPacket)
@@ -445,7 +342,7 @@ func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) {
err := c.WriterPacket(unsuback) err := c.WriterPacket(unsuback)
if err != nil { if err != nil {
brokerLog.Error("send unsuback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) log.Error("send unsuback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
return return
} }
// //process ubsubscribe message // //process ubsubscribe message
@@ -454,19 +351,6 @@ func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) {
} }
} }
func (c *client) unsubscribe(sub *subscription) {
if c.typ == CLIENT {
delete(c.subs, sub.topic)
}
b := c.broker
if b != nil && sub != nil {
b.sl.Remove(sub)
}
}
func (c *client) ProcessPing() { func (c *client) ProcessPing() {
if c.status == Disconnected { if c.status == Disconnected {
return return
@@ -474,43 +358,39 @@ func (c *client) ProcessPing() {
resp := packets.NewControlPacket(packets.Pingresp).(*packets.PingrespPacket) resp := packets.NewControlPacket(packets.Pingresp).(*packets.PingrespPacket)
err := c.WriterPacket(resp) err := c.WriterPacket(resp)
if err != nil { if err != nil {
brokerLog.Error("send PingResponse error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) log.Error("send PingResponse error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
return return
} }
} }
func (c *client) Close() { func (c *client) Close() {
c.smu.Lock()
if c.status == Disconnected { if c.status == Disconnected {
c.smu.Unlock()
return return
} }
c.cancelFunc()
c.status = Disconnected
//wait for message complete //wait for message complete
time.Sleep(1 * time.Second) time.Sleep(1 * time.Second)
c.status = Disconnected // c.status = Disconnected
if c.conn != nil { if c.conn != nil {
c.conn.Close() c.conn.Close()
c.conn = nil c.conn = nil
} }
c.smu.Unlock()
close(c.closed)
b := c.broker b := c.broker
subs := c.subs subs := c.subMap
if b != nil { if b != nil {
b.removeClient(c) b.removeClient(c)
for _, sub := range subs {
err := b.sl.Remove(sub)
if err != nil {
brokerLog.Error("closed client but remove sublist error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
}
}
if c.typ == CLIENT { if c.typ == CLIENT {
b.BroadcastUnSubscribe(subs) b.BroadcastUnSubscribe(subs)
//offline notification
b.OnlineOfflineNotification(c.info.clientID, false)
} }
if c.info.willMsg != nil { if c.info.willMsg != nil {
b.PublishMessage(c.info.willMsg) b.PublishMessage(c.info.willMsg)
} }
@@ -527,9 +407,17 @@ func (c *client) Close() {
} }
func (c *client) WriterPacket(packet packets.ControlPacket) error { func (c *client) WriterPacket(packet packets.ControlPacket) error {
if c.status == Disconnected {
return nil
}
if packet == nil { if packet == nil {
return nil return nil
} }
if c.conn == nil {
c.Close()
return errors.New("connect lost ....")
}
c.mu.Lock() c.mu.Lock()
err := packet.Write(c.conn) err := packet.Write(c.conn)

View File

@@ -7,10 +7,8 @@ import (
"crypto/rand" "crypto/rand"
"encoding/base64" "encoding/base64"
"encoding/hex" "encoding/hex"
"errors"
"io" "io"
"reflect" "reflect"
"strings"
"time" "time"
) )
@@ -48,47 +46,6 @@ const (
QosFailure = 0x80 QosFailure = 0x80
) )
// SubscribeTopicCheckAndSpilt validates an MQTT subscribe filter and splits
// it into its "/"-separated levels. "#" is only legal as the final character;
// interior levels must be non-empty and may use "+" only as a whole level.
// Empty first/last levels (leading or trailing "/") are normalized to the
// literal "/" token.
func SubscribeTopicCheckAndSpilt(topic string) ([]string, error) {
	// The multi-level wildcard may appear only at the very end of the filter.
	if idx := strings.Index(topic, "#"); idx != -1 && idx != len(topic)-1 {
		return nil, errors.New("Topic format error with index of #")
	}
	levels := strings.Split(topic, "/")
	last := len(levels) - 1
	for i, level := range levels {
		if i == 0 || i == last {
			// A leading/trailing separator produces an empty edge level,
			// which is represented as the "/" token downstream.
			if level == "" {
				levels[i] = "/"
			}
			continue
		}
		if level == "" {
			return nil, errors.New("Topic format error with index of //")
		}
		if level != "+" && strings.Contains(level, "+") {
			return nil, errors.New("Topic format error with index of +")
		}
	}
	return levels, nil
}
// PublishTopicCheckAndSpilt validates an MQTT publish topic and splits it
// into its "/"-separated levels. Publish topics must be literal — wildcards
// are rejected — and interior levels must be non-empty. Empty first/last
// levels are normalized to the literal "/" token.
func PublishTopicCheckAndSpilt(topic string) ([]string, error) {
	// Wildcards are never allowed in a published topic name.
	if strings.ContainsAny(topic, "#+") {
		return nil, errors.New("Publish Topic format error with + and #")
	}
	levels := strings.Split(topic, "/")
	last := len(levels) - 1
	for i, level := range levels {
		if level != "" {
			continue
		}
		if i != 0 && i != last {
			return nil, errors.New("Topic format error with index of //")
		}
		levels[i] = "/"
	}
	return levels, nil
}
func equal(k1, k2 interface{}) bool { func equal(k1, k2 interface{}) bool {
if reflect.TypeOf(k1) != reflect.TypeOf(k2) { if reflect.TypeOf(k1) != reflect.TypeOf(k2) {
return false return false

View File

@@ -52,6 +52,10 @@ var DefaultConfig *Config = &Config{
Acl: false, Acl: false,
} }
var (
log *zap.Logger
)
func showHelp() { func showHelp() {
fmt.Printf("%s\n", usageStr) fmt.Printf("%s\n", usageStr)
os.Exit(0) os.Exit(0)
@@ -105,7 +109,7 @@ func ConfigureConfig(args []string) (*Config, error) {
}) })
logger.InitLogger(config.Debug) logger.InitLogger(config.Debug)
brokerLog = logger.Get().Named("Broker") log = logger.Get().Named("Broker")
if configFile != "" { if configFile != "" {
tmpConfig, e := LoadConfig(configFile) tmpConfig, e := LoadConfig(configFile)
@@ -128,15 +132,15 @@ func LoadConfig(filename string) (*Config, error) {
content, err := ioutil.ReadFile(filename) content, err := ioutil.ReadFile(filename)
if err != nil { if err != nil {
brokerLog.Error("Read config file error: ", zap.Error(err)) log.Error("Read config file error: ", zap.Error(err))
return nil, err return nil, err
} }
// brokerLog.Info(string(content)) // log.Info(string(content))
var config Config var config Config
err = json.Unmarshal(content, &config) err = json.Unmarshal(content, &config)
if err != nil { if err != nil {
brokerLog.Error("Unmarshal config file error: ", zap.Error(err)) log.Error("Unmarshal config file error: ", zap.Error(err))
return nil, err return nil, err
} }
@@ -168,7 +172,7 @@ func (config *Config) check() error {
if config.TlsPort != "" { if config.TlsPort != "" {
if config.TlsInfo.CertFile == "" || config.TlsInfo.KeyFile == "" { if config.TlsInfo.CertFile == "" || config.TlsInfo.KeyFile == "" {
brokerLog.Error("tls config error, no cert or key file.") log.Error("tls config error, no cert or key file.")
return errors.New("tls config error, no cert or key file.") return errors.New("tls config error, no cert or key file.")
} }
if config.TlsHost == "" { if config.TlsHost == "" {

View File

@@ -6,10 +6,9 @@ import (
"fmt" "fmt"
"time" "time"
simplejson "github.com/bitly/go-simplejson"
"github.com/eclipse/paho.mqtt.golang/packets" "github.com/eclipse/paho.mqtt.golang/packets"
"go.uber.org/zap" "go.uber.org/zap"
simplejson "github.com/bitly/go-simplejson"
) )
func (c *client) SendInfo() { func (c *client) SendInfo() {
@@ -21,7 +20,7 @@ func (c *client) SendInfo() {
infoMsg := NewInfo(c.broker.id, url, false) infoMsg := NewInfo(c.broker.id, url, false)
err := c.WriterPacket(infoMsg) err := c.WriterPacket(infoMsg)
if err != nil { if err != nil {
brokerLog.Error("send info message error, ", zap.Error(err)) log.Error("send info message error, ", zap.Error(err))
return return
} }
} }
@@ -34,13 +33,11 @@ func (c *client) StartPing() {
case <-timeTicker.C: case <-timeTicker.C:
err := c.WriterPacket(ping) err := c.WriterPacket(ping)
if err != nil { if err != nil {
brokerLog.Error("ping error: ", zap.Error(err)) log.Error("ping error: ", zap.Error(err))
c.Close() c.Close()
} }
case _, ok := <-c.closed: case <-c.ctx.Done():
if !ok { return
return
}
} }
} }
} }
@@ -57,10 +54,10 @@ func (c *client) SendConnect() {
m.Keepalive = uint16(60) m.Keepalive = uint16(60)
err := c.WriterPacket(m) err := c.WriterPacket(m)
if err != nil { if err != nil {
brokerLog.Error("send connect message error, ", zap.Error(err)) log.Error("send connect message error, ", zap.Error(err))
return return
} }
brokerLog.Info("send connect success") log.Info("send connect success")
} }
func NewInfo(sid, url string, isforword bool) *packets.PublishPacket { func NewInfo(sid, url string, isforword bool) *packets.PublishPacket {
@@ -69,7 +66,7 @@ func NewInfo(sid, url string, isforword bool) *packets.PublishPacket {
pub.TopicName = BrokerInfoTopic pub.TopicName = BrokerInfoTopic
pub.Retain = false pub.Retain = false
info := fmt.Sprintf(`{"brokerID":"%s","brokerUrl":"%s"}`, sid, url) info := fmt.Sprintf(`{"brokerID":"%s","brokerUrl":"%s"}`, sid, url)
// brokerLog.Info("new info", string(info)) // log.Info("new info", string(info))
pub.Payload = []byte(info) pub.Payload = []byte(info)
return pub return pub
} }
@@ -81,17 +78,17 @@ func (c *client) ProcessInfo(packet *packets.PublishPacket) {
return return
} }
brokerLog.Info("recv remoteInfo: ", zap.String("payload", string(packet.Payload))) log.Info("recv remoteInfo: ", zap.String("payload", string(packet.Payload)))
js, err := simplejson.NewJson(packet.Payload) js, err := simplejson.NewJson(packet.Payload)
if err != nil { if err != nil {
brokerLog.Warn("parse info message err", zap.Error(err)) log.Warn("parse info message err", zap.Error(err))
return return
} }
routes, err := js.Get("data").Map() routes, err := js.Get("data").Map()
if routes == nil { if routes == nil {
brokerLog.Error("receive info message error, ", zap.Error(err)) log.Error("receive info message error, ", zap.Error(err))
return return
} }

View File

@@ -1,122 +0,0 @@
package broker
import (
"sync"
"github.com/eclipse/paho.mqtt.golang/packets"
)
// RetainList stores retained PUBLISH packets in a topic-level trie so the
// retained message(s) matching a new subscription can be looked up quickly.
type RetainList struct {
	sync.RWMutex // guards root
	root         *rlevel
}

// rlevel is one trie level: one child node per topic-level token.
type rlevel struct {
	nodes map[string]*rnode
}

// rnode holds the retained message stored at exactly this topic (if any)
// plus the next deeper level.
type rnode struct {
	next *rlevel
	msg  *packets.PublishPacket
}

// RetainResult accumulates the retained messages collected during a match.
type RetainResult struct {
	msg []*packets.PublishPacket
}

// newRNode creates an empty trie node.
func newRNode() *rnode {
	return &rnode{}
}

// newRLevel creates an empty trie level with an initialized child map.
func newRLevel() *rlevel {
	return &rlevel{nodes: make(map[string]*rnode)}
}

// NewRetainList creates an empty retained-message trie.
func NewRetainList() *RetainList {
	return &RetainList{root: newRLevel()}
}
// Insert stores buf as the retained message for topic, replacing any
// previous retained message at that exact topic. The topic must be a
// literal (wildcard-free) publish topic.
func (r *RetainList) Insert(topic string, buf *packets.PublishPacket) error {
	tokens, err := PublishTopicCheckAndSpilt(topic)
	if err != nil {
		return err
	}
	r.Lock()
	l := r.root
	var n *rnode
	// Walk the trie, creating one level per topic token as needed.
	for _, t := range tokens {
		n = l.nodes[t]
		if n == nil {
			n = newRNode()
			l.nodes[t] = n
		}
		if n.next == nil {
			n.next = newRLevel()
		}
		l = n.next
	}
	// n is now the node for the final token; the message lives there.
	n.msg = buf
	r.Unlock()
	return nil
}
// Match returns every retained message whose topic matches the subscribe
// filter (supporting "+" and "#" wildcards), or nil if the filter is invalid.
func (r *RetainList) Match(topic string) []*packets.PublishPacket {
	tokens, err := SubscribeTopicCheckAndSpilt(topic)
	if err != nil {
		return nil
	}
	results := &RetainResult{}
	// NOTE(review): a write lock is taken for what looks like a read-only
	// walk; RLock would probably suffice — confirm before changing.
	r.Lock()
	l := r.root
	matchRLevel(l, tokens, results)
	r.Unlock()
	return results.msg
}
// matchRLevel walks the retained-message trie level by level, appending every
// retained packet matching the subscribe tokens in toks to results.
// "#" collects the whole subtree at that point; "+" branches into every
// child of the current level.
func matchRLevel(l *rlevel, toks []string, results *RetainResult) {
	var n *rnode
	for i, t := range toks {
		if l == nil {
			return
		}
		if t == "#" {
			// Multi-level wildcard: everything below this level matches.
			for _, child := range l.nodes {
				child.GetAll(results)
			}
		}
		if t == "+" {
			// Single-level wildcard: branch into every child at this level.
			for _, child := range l.nodes {
				// BUG FIX: the original tested len(t[i+1:]) — slicing the
				// one-character token string "+", which panics for i >= 1
				// (e.g. filter "a/+"). The remaining *tokens* are meant.
				if len(toks[i+1:]) == 0 {
					// "+" is the last token: the child itself is the match.
					if child.msg != nil {
						results.msg = append(results.msg, child.msg)
					}
				} else {
					matchRLevel(child.next, toks[i+1:], results)
				}
			}
		}
		n = l.nodes[t]
		if n != nil {
			l = n.next
		} else {
			l = nil
		}
	}
	// Literal walk consumed all tokens: report the final node's message.
	// Skip nodes that exist only as interior levels (nil msg).
	if n != nil && n.msg != nil {
		results.msg = append(results.msg, n.msg)
	}
}
// GetAll appends this node's retained message (if any) and every retained
// message in the subtree below it to results.
func (r *rnode) GetAll(results *RetainResult) {
	if r.msg != nil {
		results.msg = append(results.msg, r.msg)
	}
	// Insert always allocates next for trie-built nodes, so l is non-nil
	// for any node reachable from the root.
	l := r.next
	for _, n := range l.nodes {
		n.GetAll(results)
	}
}

53
broker/sesson.go Normal file
View File

@@ -0,0 +1,53 @@
package broker
import "github.com/eclipse/paho.mqtt.golang/packets"
// getSession attaches a session to cli according to the MQTT CleanSession
// rules and records in resp.SessionPresent whether an existing session was
// resumed.
//
// If CleanSession is set to 0, the server MUST resume communications with the
// client based on state from the current session, as identified by the client
// identifier. If there is no session associated with the client identifier the
// server must create a new session.
//
// If CleanSession is set to 1, the client and server must discard any previous
// session and start a new one. This session lasts as long as the network
// connection. State data associated with this session must not be reused in
// any subsequent session.
func (b *Broker) getSession(cli *client, req *packets.ConnectPacket, resp *packets.ConnackPacket) error {
	var err error

	// Check to see if the client supplied an ID, if not, generate one and set
	// clean session.
	if len(req.ClientIdentifier) == 0 {
		req.CleanSession = true
	}

	cid := req.ClientIdentifier

	// If CleanSession is NOT set, check the session store for existing session.
	// If found, return it.
	if !req.CleanSession {
		if cli.session, err = b.sessionMgr.Get(cid); err == nil {
			resp.SessionPresent = true
			if err := cli.session.Update(req); err != nil {
				return err
			}
		}
	}

	// If CleanSession, or no existing session found, then create a new one
	if cli.session == nil {
		if cli.session, err = b.sessionMgr.New(cid); err != nil {
			return err
		}
		resp.SessionPresent = false
		if err := cli.session.Init(req); err != nil {
			return err
		}
	}

	return nil
}

View File

@@ -1,318 +0,0 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
import (
"errors"
"sync"
"go.uber.org/zap"
)
// A result structure better optimized for queue subs.
type SublistResult struct {
	psubs []*subscription
	qsubs []*subscription // don't make this a map, too expensive to iterate
}

// A Sublist stores and efficiently retrieves subscriptions. It is a
// topic-level trie plus a flat result cache keyed by publish topic.
type Sublist struct {
	sync.RWMutex
	cache map[string]*SublistResult
	root  *level
}

// A node contains subscriptions and a pointer to the next level.
type node struct {
	next  *level
	psubs []*subscription // plain subscriptions
	qsubs []*subscription // "$queue/" shared subscriptions
}

// A level represents a group of nodes and special pointers to
// wildcard nodes.
type level struct {
	nodes map[string]*node
}

// Create a new default node with small pre-sized subscription lists.
func newNode() *node {
	return &node{psubs: make([]*subscription, 0, 4), qsubs: make([]*subscription, 0, 4)}
}

// Create a new default level; children are keyed by topic token.
func newLevel() *level {
	return &level{nodes: make(map[string]*node)}
}

// New will create a default sublist
func NewSublist() *Sublist {
	return &Sublist{root: newLevel(), cache: make(map[string]*SublistResult)}
}
// Insert adds a subscription into the sublist. The filter is split into
// levels, each descending one trie step. An equal existing subscription is
// replaced in place; otherwise the subscription is appended to the node's
// queue or plain list, and matching cache entries are updated.
func (s *Sublist) Insert(sub *subscription) error {
	tokens, err := SubscribeTopicCheckAndSpilt(sub.topic)
	if err != nil {
		return err
	}
	s.Lock()
	// BUG FIX: the original early-returned from inside the replace loops
	// below without calling s.Unlock(), permanently deadlocking the sublist
	// the first time a duplicate subscription was inserted. defer releases
	// the lock on every path.
	defer s.Unlock()
	l := s.root
	var n *node
	for _, t := range tokens {
		n = l.nodes[t]
		if n == nil {
			n = newNode()
			l.nodes[t] = n
		}
		if n.next == nil {
			n.next = newLevel()
		}
		l = n.next
	}
	if sub.queue {
		// check qsub is already exist
		for i := range n.qsubs {
			if equal(n.qsubs[i], sub) {
				n.qsubs[i] = sub
				return nil
			}
		}
		n.qsubs = append(n.qsubs, sub)
	} else {
		// check psub is already exist
		for i := range n.psubs {
			if equal(n.psubs[i], sub) {
				n.psubs[i] = sub
				return nil
			}
		}
		n.psubs = append(n.psubs, sub)
	}
	s.addToCache(sub.topic, sub)
	return nil
}
// addToCache folds a freshly inserted subscription into every cached result
// whose literal topic is matched by the new subscription's filter.
func (s *Sublist) addToCache(topic string, sub *subscription) {
	for key, cached := range s.cache {
		if !matchLiteral(key, topic) {
			continue
		}
		// Copy since others may have a reference.
		updated := copyResult(cached)
		if sub.queue {
			updated.qsubs = append(updated.qsubs, sub)
		} else {
			updated.psubs = append(updated.psubs, sub)
		}
		s.cache[key] = updated
	}
}
// removeFromCache invalidates every cached result whose literal topic is
// matched by the removed subscription's filter. Since someone else may be
// referencing the cached slices, entries are dropped (not mutated) and
// re-populate lazily on the next Match.
func (s *Sublist) removeFromCache(topic string, sub *subscription) {
	for key := range s.cache {
		if matchLiteral(key, topic) {
			delete(s.cache, key)
		}
	}
}
// matchLiteral reports whether the literal (wildcard-free) cache key is
// matched by the subscribe filter topic. "#" matches the remainder; "+"
// matches any single level; other tokens must be literally equal.
func matchLiteral(literal, topic string) bool {
	tok, _ := SubscribeTopicCheckAndSpilt(topic)
	li, _ := PublishTopicCheckAndSpilt(literal)
	for i := 0; i < len(tok); i++ {
		b := tok[i]
		switch b {
		case "+":
			// Single-level wildcard: nothing to compare at this level.
		case "#":
			return true
		default:
			// BUG FIX: the original indexed li[i] unchecked and panicked
			// whenever the filter had more levels than the literal topic.
			if i >= len(li) || b != li[i] {
				return false
			}
		}
	}
	return true
}
// copyResult returns a deep copy of r's subscription slices so that the
// original (which other readers may still hold) stays untouched.
func copyResult(r *SublistResult) *SublistResult {
	dup := &SublistResult{
		psubs: append([]*subscription(nil), r.psubs...),
		qsubs: append([]*subscription(nil), r.qsubs...),
	}
	return dup
}
// Remove deletes sub from the sublist and invalidates affected cache
// entries. It returns an error when no node or subscription matches.
func (s *Sublist) Remove(sub *subscription) error {
	tokens, err := SubscribeTopicCheckAndSpilt(sub.topic)
	if err != nil {
		return err
	}
	s.Lock()
	defer s.Unlock()
	l := s.root
	var n *node
	// Walk the trie to the node that would hold this filter's subscriptions.
	for _, t := range tokens {
		if l == nil {
			return errors.New("No Matches subscription Found")
		}
		n = l.nodes[t]
		if n != nil {
			l = n.next
		} else {
			l = nil
		}
	}
	if !s.removeFromNode(n, sub) {
		return errors.New("No Matches subscription Found")
	}
	topic := string(sub.topic)
	s.removeFromCache(topic, sub)
	return nil
}
// removeFromNode deletes sub from the node's queue or plain subscription
// list (QoS is ignored; pointer identity decides), reporting whether the
// subscription was found. The original ended with an unreachable
// `return false` after both branches returned — dropped.
func (s *Sublist) removeFromNode(n *node, sub *subscription) (found bool) {
	if n == nil {
		return false
	}
	if sub.queue {
		n.qsubs, found = removeSubFromList(sub, n.qsubs)
	} else {
		n.psubs, found = removeSubFromList(sub, n.psubs)
	}
	return found
}
// Match returns the subscriptions whose filters match the published topic.
// Results are memoized in a bounded (1024-entry) cache; invalid topics
// return nil.
func (s *Sublist) Match(topic string) *SublistResult {
	// Fast path: serve a previously computed result from the cache.
	s.RLock()
	rc, ok := s.cache[topic]
	s.RUnlock()
	if ok {
		return rc
	}

	tokens, err := PublishTopicCheckAndSpilt(topic)
	if err != nil {
		brokerLog.Error("\tserver/sublist.go: ", zap.Error(err))
		return nil
	}

	result := &SublistResult{}
	s.Lock()
	defer s.Unlock()
	l := s.root
	if len(tokens) > 0 {
		if tokens[0] == "/" {
			// A leading "/" was normalized into its own token; check the
			// root's wildcard nodes explicitly before descending.
			if _, exist := l.nodes["#"]; exist {
				addNodeToResults(l.nodes["#"], result)
			}
			if _, exist := l.nodes["+"]; exist {
				// BUG FIX: the original descended l.nodes["/"].next here,
				// which is the wrong branch and panics with a nil-map
				// dereference when no "/" node exists.
				matchLevel(l.nodes["+"].next, tokens[1:], result)
			}
			if _, exist := l.nodes["/"]; exist {
				matchLevel(l.nodes["/"].next, tokens[1:], result)
			}
		} else {
			matchLevel(s.root, tokens, result)
		}
	}
	// Memoize; evict one arbitrary entry when the cache grows past 1024.
	s.cache[topic] = result
	if len(s.cache) > 1024 {
		for k := range s.cache {
			delete(s.cache, k)
			break
		}
	}
	return result
}
// matchLevel recursively collects into results every subscription whose
// filter matches the remaining topic tokens, starting at level l. "#" nodes
// match the remainder at any depth; "+" nodes branch in parallel with the
// literal path.
func matchLevel(l *level, toks []string, results *SublistResult) {
	var swc, n *node
	exist := false
	for i, t := range toks {
		if l == nil {
			return
		}
		// A "#" subscription at any level matches the whole remainder.
		if _, exist = l.nodes["#"]; exist {
			addNodeToResults(l.nodes["#"], results)
		}
		if t != "/" {
			// Branch into the single-level wildcard alongside the literal
			// descent below; remember it for the post-loop flush.
			if swc, exist = l.nodes["+"]; exist {
				matchLevel(l.nodes["+"].next, toks[i+1:], results)
			}
		} else {
			if _, exist = l.nodes["+"]; exist {
				addNodeToResults(l.nodes["+"], results)
			}
		}
		n = l.nodes[t]
		if n != nil {
			l = n.next
		} else {
			l = nil
		}
	}
	if n != nil {
		addNodeToResults(n, results)
	}
	if swc != nil {
		// BUG FIX: the original added n here, duplicating the literal node's
		// subscriptions (and dereferencing nil when n was nil) instead of
		// adding the final-level "+" node's own subscribers.
		addNodeToResults(swc, results)
	}
}
// This will add in a node's results to the total results: both the plain
// and the queue subscription lists are appended.
func addNodeToResults(n *node, results *SublistResult) {
	results.psubs = append(results.psubs, n.psubs...)
	results.qsubs = append(results.qsubs, n.qsubs...)
}
// removeSubFromList removes sub (by pointer identity) from sl, returning the
// possibly shrunk slice and whether the subscription was found.
func removeSubFromList(sub *subscription, sl []*subscription) ([]*subscription, bool) {
	for i := range sl {
		if sl[i] != sub {
			continue
		}
		// Swap the last element into the hole, clear the vacated tail slot
		// so the GC can reclaim it, then truncate.
		last := len(sl) - 1
		sl[i], sl[last] = sl[last], nil
		return shrinkAsNeeded(sl[:last]), true
	}
	return sl, false
}
// shrinkAsNeeded reallocates sl when more than half of a non-trivial backing
// array is unused, so a burst of subscriptions followed by unsubscribes does
// not pin memory indefinitely.
func shrinkAsNeeded(sl []*subscription) []*subscription {
	length, capacity := len(sl), cap(sl)
	// Small lists are not worth compacting.
	if capacity <= 8 {
		return sl
	}
	if free := float32(capacity-length) / float32(capacity); free > 0.50 {
		return append([]*subscription(nil), sl...)
	}
	return sl
}

View File

@@ -0,0 +1,76 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package sessions
import (
"fmt"
"sync"
)
// Compile-time check that memProvider satisfies SessionsProvider.
var _ SessionsProvider = (*memProvider)(nil)

func init() {
	// Make the in-memory provider available under the "mem" name.
	Register("mem", NewMemProvider())
}

// memProvider keeps sessions in a mutex-guarded in-process map; nothing is
// persisted across restarts.
type memProvider struct {
	st map[string]*Session
	mu sync.RWMutex
}

// NewMemProvider creates an empty in-memory session store.
func NewMemProvider() *memProvider {
	return &memProvider{
		st: make(map[string]*Session),
	}
}
// New creates (or silently replaces) the session stored under id and
// returns it.
func (this *memProvider) New(id string) (*Session, error) {
	this.mu.Lock()
	defer this.mu.Unlock()
	this.st[id] = &Session{id: id}
	return this.st[id], nil
}

// Get returns the session stored under id, or an error if none exists.
func (this *memProvider) Get(id string) (*Session, error) {
	this.mu.RLock()
	defer this.mu.RUnlock()
	sess, ok := this.st[id]
	if !ok {
		return nil, fmt.Errorf("store/Get: No session found for key %s", id)
	}
	return sess, nil
}
// Del removes the session stored under id, if present.
func (this *memProvider) Del(id string) {
	this.mu.Lock()
	defer this.mu.Unlock()
	delete(this.st, id)
}

// Save is a no-op: sessions live in memory and need no persistence step.
func (this *memProvider) Save(id string) error {
	return nil
}

// Count returns the number of stored sessions.
func (this *memProvider) Count() int {
	// BUG FIX: the original read the map without the lock — a data race
	// with concurrent New/Del (flagged by the race detector).
	this.mu.RLock()
	defer this.mu.RUnlock()
	return len(this.st)
}

// Close drops every stored session; the provider remains usable.
func (this *memProvider) Close() error {
	// BUG FIX: same unguarded map access as Count; replacing the map while
	// another goroutine reads it is a race.
	this.mu.Lock()
	defer this.mu.Unlock()
	this.st = make(map[string]*Session)
	return nil
}

View File

@@ -0,0 +1,95 @@
package sessions
import (
"time"
log "github.com/cihub/seelog"
"github.com/go-redis/redis"
jsoniter "github.com/json-iterator/go"
)
var redisClient *redis.Client
var _ SessionsProvider = (*redisProvider)(nil)
const (
sessionName = "session"
)
type redisProvider struct {
}
func init() {
Register("redis", NewRedisProvider())
}
// InitRedisConn establishes the package-wide redis client used by the
// provider, retrying every 3s until the server answers a PING.
// An empty url falls back to the historical default 127.0.0.1:6379.
// NOTE(review): url is assumed to be a plain "host:port" address, not a
// redis:// URL — confirm against callers.
func InitRedisConn(url string) {
	// BUG FIX: the original ignored the url parameter entirely and always
	// connected to the hardcoded default address.
	addr := url
	if addr == "" {
		addr = "127.0.0.1:6379"
	}
	redisClient = redis.NewClient(&redis.Options{
		Addr:     addr,
		Password: "", // no password set
		DB:       0,  // use default DB
	})
	err := redisClient.Ping().Err()
	for err != nil {
		log.Error("connect redis error: ", err, " 3s try again...")
		time.Sleep(3 * time.Second)
		err = redisClient.Ping().Err()
	}
}
func NewRedisProvider() *redisProvider {
return &redisProvider{}
}
// New stores a fresh session under id in the redis hash and reads it back.
// NOTE(review): Session's id field is unexported, so jsoniter will not
// round-trip it — the session read back is effectively zero-valued.
// Confirm whether this is intended before relying on the returned value.
func (r *redisProvider) New(id string) (*Session, error) {
	val, _ := jsoniter.Marshal(&Session{id: id})
	err := redisClient.HSet(sessionName, id, val).Err()
	if err != nil {
		return nil, err
	}
	result, err := redisClient.HGet(sessionName, id).Bytes()
	if err != nil {
		return nil, err
	}
	sess := Session{}
	err = jsoniter.Unmarshal(result, &sess)
	if err != nil {
		return nil, err
	}
	return &sess, nil
}
// Get loads and decodes the session stored under id from the redis hash.
func (r *redisProvider) Get(id string) (*Session, error) {
	result, err := redisClient.HGet(sessionName, id).Bytes()
	if err != nil {
		return nil, err
	}
	sess := Session{}
	err = jsoniter.Unmarshal(result, &sess)
	if err != nil {
		return nil, err
	}
	return &sess, nil
}

// Del removes the hash field for id.
// NOTE(review): the HDel error is discarded — confirm best-effort deletion
// is intended.
func (r *redisProvider) Del(id string) {
	redisClient.HDel(sessionName, id)
}

// Save is a no-op: mutations are written through to redis directly.
func (r *redisProvider) Save(id string) error {
	return nil
}

// Count returns the number of stored sessions (0 on redis errors).
func (r *redisProvider) Count() int {
	return int(redisClient.HLen(sessionName).Val())
}

// Close deletes the whole session hash from redis.
func (r *redisProvider) Close() error {
	return redisClient.Del(sessionName).Err()
}

149
lib/sessions/session.go Normal file
View File

@@ -0,0 +1,149 @@
package sessions
import (
"fmt"
"sync"
"github.com/eclipse/paho.mqtt.golang/packets"
)
const (
	// Queue size for the ack queue
	defaultQueueSize = 16
)

// Session carries the per-client MQTT state: the CONNECT packet, the will
// message, the retained message, and the client's subscriptions.
type Session struct {
	// cmsg is the CONNECT message
	cmsg *packets.ConnectPacket

	// Will message to publish if connect is closed unexpectedly
	Will *packets.PublishPacket

	// Retained publish message
	Retained *packets.PublishPacket

	// topics stores all the topics (filter -> granted QoS) for this session/client
	topics map[string]byte

	// Initialized?
	initted bool

	// Serialize access to this session
	mu sync.Mutex

	// id is the client identifier this session was created for.
	id string
}
// Init primes the session from a CONNECT packet, building the will message
// if one was announced. It may be called only once per session; subsequent
// calls return an error.
func (this *Session) Init(msg *packets.ConnectPacket) error {
	this.mu.Lock()
	defer this.mu.Unlock()

	if this.initted {
		return fmt.Errorf("Session already initialized")
	}

	this.cmsg = msg

	if this.cmsg.WillFlag {
		this.Will = packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
		// BUG FIX: the will QoS comes from the CONNECT packet's WillQos
		// field. The original read this.cmsg.Qos — the fixed-header QoS of
		// the CONNECT packet itself, which is always 0 — so every will
		// message was silently downgraded to QoS 0.
		this.Will.Qos = this.cmsg.WillQos
		this.Will.TopicName = this.cmsg.WillTopic
		this.Will.Payload = this.cmsg.WillMessage
		this.Will.Retain = this.cmsg.WillRetain
	}

	this.topics = make(map[string]byte, 1)
	this.id = msg.ClientIdentifier
	this.initted = true

	return nil
}
// Update replaces the stored CONNECT packet when a client reconnects to an
// existing (non-clean) session.
func (this *Session) Update(msg *packets.ConnectPacket) error {
	this.mu.Lock()
	defer this.mu.Unlock()
	this.cmsg = msg
	return nil
}

// RetainMessage remembers msg as this session's retained publish message.
func (this *Session) RetainMessage(msg *packets.PublishPacket) error {
	this.mu.Lock()
	defer this.mu.Unlock()
	this.Retained = msg
	return nil
}
// AddTopic records a subscription (topic filter -> granted QoS) on the
// session. The session must have been initialized first.
func (this *Session) AddTopic(topic string, qos byte) error {
	this.mu.Lock()
	defer this.mu.Unlock()
	if !this.initted {
		return fmt.Errorf("Session not yet initialized")
	}
	this.topics[topic] = qos
	return nil
}

// RemoveTopic forgets a previously recorded subscription. The session must
// have been initialized first.
func (this *Session) RemoveTopic(topic string) error {
	this.mu.Lock()
	defer this.mu.Unlock()
	if !this.initted {
		return fmt.Errorf("Session not yet initialized")
	}
	delete(this.topics, topic)
	return nil
}
// Topics returns a snapshot of every subscribed topic alongside its QoS.
// The two slices are index-aligned; ordering is unspecified (map iteration).
func (this *Session) Topics() ([]string, []byte, error) {
	this.mu.Lock()
	defer this.mu.Unlock()

	if !this.initted {
		return nil, nil, fmt.Errorf("Session not yet initialized")
	}

	var (
		names []string
		qoss  []byte
	)
	for topic, q := range this.topics {
		names = append(names, topic)
		qoss = append(qoss, q)
	}
	return names, qoss, nil
}
// ID returns the client identifier from the CONNECT packet.
// NOTE(review): unlike the other accessors this takes no lock and reads
// cmsg (nil before Init) rather than the id field — confirm before use.
func (this *Session) ID() string {
	return this.cmsg.ClientIdentifier
}

// WillFlag reports whether the CONNECT packet announced a will message.
func (this *Session) WillFlag() bool {
	this.mu.Lock()
	defer this.mu.Unlock()
	return this.cmsg.WillFlag
}

// SetWillFlag overrides the stored will flag.
// NOTE(review): presumably used to suppress the will after a clean
// disconnect — verify against callers.
func (this *Session) SetWillFlag(v bool) {
	this.mu.Lock()
	defer this.mu.Unlock()
	this.cmsg.WillFlag = v
}

// CleanSession reports the CONNECT packet's CleanSession flag.
func (this *Session) CleanSession() bool {
	this.mu.Lock()
	defer this.mu.Unlock()
	return this.cmsg.CleanSession
}

92
lib/sessions/sessions.go Normal file
View File

@@ -0,0 +1,92 @@
package sessions
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"io"
)
var (
	// ErrSessionsProviderNotFound is returned when a named provider is missing.
	ErrSessionsProviderNotFound = errors.New("Session: Session provider not found")
	// ErrKeyNotAvailable is returned when no item exists for a key.
	ErrKeyNotAvailable = errors.New("Session: not item found for key.")

	// providers holds every registered SessionsProvider, keyed by name.
	providers = make(map[string]SessionsProvider)
)

// SessionsProvider is the storage-backend contract for sessions
// (e.g. in-memory or redis).
type SessionsProvider interface {
	New(id string) (*Session, error)
	Get(id string) (*Session, error)
	Del(id string)
	Save(id string) error
	Count() int
	Close() error
}
// Register makes a session provider available by the provided name.
// If a Register is called twice with the same name or if the driver is nil,
// it panics.
func Register(name string, provider SessionsProvider) {
	if provider == nil {
		panic("session: Register provide is nil")
	}

	if _, dup := providers[name]; dup {
		panic("session: Register called twice for provider " + name)
	}

	providers[name] = provider
}

// Unregister removes the provider registered under name, if any.
func Unregister(name string) {
	delete(providers, name)
}
// Manager wraps a registered SessionsProvider behind a uniform API and
// generates random ids when callers do not supply one.
type Manager struct {
	p SessionsProvider
}

// NewManager returns a Manager backed by the provider registered under
// providerName, or an error if no such provider exists.
func NewManager(providerName string) (*Manager, error) {
	p, ok := providers[providerName]
	if !ok {
		return nil, fmt.Errorf("session: unknown provider %q", providerName)
	}

	return &Manager{p: p}, nil
}
// New creates a session under id, generating a random id when empty.
func (this *Manager) New(id string) (*Session, error) {
	if id == "" {
		id = this.sessionId()
	}
	return this.p.New(id)
}

// Get fetches the session stored under id.
func (this *Manager) Get(id string) (*Session, error) {
	return this.p.Get(id)
}

// Del removes the session stored under id.
func (this *Manager) Del(id string) {
	this.p.Del(id)
}

// Save persists the session stored under id (provider-dependent; may be a
// no-op).
func (this *Manager) Save(id string) error {
	return this.p.Save(id)
}

// Count reports how many sessions the provider currently stores.
func (this *Manager) Count() int {
	return this.p.Count()
}

// Close releases the provider's resources and drops its sessions.
func (this *Manager) Close() error {
	return this.p.Close()
}

// sessionId returns a 120-bit random id in URL-safe base64, or "" if the
// system's entropy source fails.
func (manager *Manager) sessionId() string {
	b := make([]byte, 15)
	if _, err := io.ReadFull(rand.Reader, b); err != nil {
		return ""
	}
	return base64.URLEncoding.EncodeToString(b)
}

549
lib/topics/memtopics.go Normal file
View File

@@ -0,0 +1,549 @@
package topics
import (
"fmt"
"reflect"
"sync"
"github.com/eclipse/paho.mqtt.golang/packets"
)
const (
	// MQTT delivery QoS levels.
	QosAtMostOnce byte = iota
	QosAtLeastOnce
	QosExactlyOnce

	// QosFailure is the SUBACK return code signalling a rejected subscription.
	QosFailure = 0x80
)

// Compile-time check that memTopics satisfies TopicsProvider.
var _ TopicsProvider = (*memTopics)(nil)

// memTopics keeps subscriptions and retained messages in two in-memory
// topic trees, each guarded by its own RWMutex.
type memTopics struct {
	// Sub/unsub mutex
	smu sync.RWMutex
	// Subscription tree
	sroot *snode

	// Retained message mutex
	rmu sync.RWMutex
	// Retained messages topic tree
	rroot *rnode
}

func init() {
	// Make the in-memory topics provider available under the "mem" name.
	Register("mem", NewMemProvider())
}

// NewMemProvider returns an new instance of the memTopics, which is implements the
// TopicsProvider interface. memProvider is a hidden struct that stores the topic
// subscriptions and retained messages in memory. The content is not persistend so
// when the server goes, everything will be gone. Use with care.
func NewMemProvider() *memTopics {
	return &memTopics{
		sroot: newSNode(),
		rroot: newRNode(),
	}
}
// ValidQos reports whether qos is one of the three MQTT delivery levels.
func ValidQos(qos byte) bool {
	switch qos {
	case QosAtMostOnce, QosAtLeastOnce, QosExactlyOnce:
		return true
	default:
		return false
	}
}
// Subscribe registers sub under topic and returns the granted QoS
// (QosFailure plus an error when the QoS or subscriber is invalid, or when
// the tree insert fails).
func (this *memTopics) Subscribe(topic []byte, qos byte, sub interface{}) (byte, error) {
	if !ValidQos(qos) {
		return QosFailure, fmt.Errorf("Invalid QoS %d", qos)
	}

	if sub == nil {
		return QosFailure, fmt.Errorf("Subscriber cannot be nil")
	}

	this.smu.Lock()
	defer this.smu.Unlock()

	// ValidQos above already guarantees qos <= QosExactlyOnce, so the
	// original's clamp (`if qos > QosExactlyOnce { qos = QosExactlyOnce }`)
	// was dead code and has been dropped.
	if err := this.sroot.sinsert(topic, qos, sub); err != nil {
		return QosFailure, err
	}

	return qos, nil
}
// Unsubscribe removes sub from the subscription tree under topic.
func (this *memTopics) Unsubscribe(topic []byte, sub interface{}) error {
	this.smu.Lock()
	defer this.smu.Unlock()

	return this.sroot.sremove(topic, sub)
}

// Subscribers collects into subs/qoss every subscriber matching topic at
// the given QoS. Returned values will be invalidated by the next
// Subscribers call.
func (this *memTopics) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
	if !ValidQos(qos) {
		return fmt.Errorf("Invalid QoS %d", qos)
	}

	this.smu.RLock()
	defer this.smu.RUnlock()

	// Reuse the callers' backing arrays between calls.
	*subs = (*subs)[0:0]
	*qoss = (*qoss)[0:0]

	return this.sroot.smatch(topic, qos, subs, qoss)
}

// Retain stores msg as the retained message for its topic; an empty payload
// deletes the retained message instead.
func (this *memTopics) Retain(msg *packets.PublishPacket) error {
	this.rmu.Lock()
	defer this.rmu.Unlock()

	// So apparently, at least according to the MQTT Conformance/Interoperability
	// Testing, that a payload of 0 means delete the retain message.
	// https://eclipse.org/paho/clients/testing/
	if len(msg.Payload) == 0 {
		return this.rroot.rremove([]byte(msg.TopicName))
	}

	return this.rroot.rinsert([]byte(msg.TopicName), msg)
}

// Retained appends every retained message matching topic to msgs.
func (this *memTopics) Retained(topic []byte, msgs *[]*packets.PublishPacket) error {
	this.rmu.RLock()
	defer this.rmu.RUnlock()

	return this.rroot.rmatch(topic, msgs)
}

// Close drops both trees; the provider is unusable afterwards.
func (this *memTopics) Close() error {
	this.sroot = nil
	this.rroot = nil
	return nil
}
// snode is one level of the subscription tree.
type snode struct {
	// If this is the end of the topic string, then add subscribers here
	subs []interface{}
	qos  []byte

	// Otherwise add the next topic level here
	snodes map[string]*snode
}

// newSNode creates an empty subscription node.
func newSNode() *snode {
	return &snode{
		snodes: make(map[string]*snode),
	}
}
// sinsert recursively inserts sub with qos under the remaining topic bytes,
// creating intermediate levels as needed. An already-present equal
// subscriber only has its QoS updated.
func (this *snode) sinsert(topic []byte, qos byte, sub interface{}) error {
	// If there's no more topic levels, that means we are at the matching snode
	// to insert the subscriber. So let's see if there's such subscriber,
	// if so, update it. Otherwise insert it.
	if len(topic) == 0 {
		// Let's see if the subscriber is already on the list. If yes, update
		// QoS and then return.
		for i := range this.subs {
			if equal(this.subs[i], sub) {
				this.qos[i] = qos
				return nil
			}
		}

		// Otherwise add.
		this.subs = append(this.subs, sub)
		this.qos = append(this.qos, qos)

		return nil
	}

	// Not the last level, so let's find or create the next level snode, and
	// recursively call it's insert().

	// ntl = next topic level
	ntl, rem, err := nextTopicLevel(topic)
	if err != nil {
		return err
	}

	level := string(ntl)

	// Add snode if it doesn't already exist
	n, ok := this.snodes[level]
	if !ok {
		n = newSNode()
		this.snodes[level] = n
	}

	return n.sinsert(rem, qos, sub)
}
// sremove deletes sub from the node addressed by the remaining topic levels,
// pruning trie branches that become completely empty on the way back up.
// This remove implementation ignores the QoS, as long as the subscriber
// matches then it's removed.
func (this *snode) sremove(topic []byte, sub interface{}) error {
	// If the topic is empty, it means we are at the final matching snode. If so,
	// let's find the matching subscribers and remove them.
	if len(topic) == 0 {
		// If subscriber == nil, then it's signal to remove ALL subscribers
		if sub == nil {
			this.subs = this.subs[0:0]
			this.qos = this.qos[0:0]
			return nil
		}
		// If we find the subscriber then remove it from the list. Technically
		// we just overwrite the slot by shifting all other items up by one.
		for i := range this.subs {
			if equal(this.subs[i], sub) {
				this.subs = append(this.subs[:i], this.subs[i+1:]...)
				this.qos = append(this.qos[:i], this.qos[i+1:]...)
				return nil
			}
		}
		return fmt.Errorf("No topic found for subscriber")
	}
	// Not the last level, so let's find the next level snode, and recursively
	// call it's remove().
	// ntl = next topic level
	ntl, rem, err := nextTopicLevel(topic)
	if err != nil {
		return err
	}
	level := string(ntl)
	// Find the snode that matches the topic level
	n, ok := this.snodes[level]
	if !ok {
		return fmt.Errorf("No topic found")
	}
	// Remove the subscriber from the next level snode
	if err := n.sremove(rem, sub); err != nil {
		return err
	}
	// If there are no more subscribers and snodes to the next level we just visited
	// let's remove it
	if len(n.subs) == 0 && len(n.snodes) == 0 {
		delete(this.snodes, level)
	}
	return nil
}
// smatch() returns all the subscribers that are subscribed to the topic. Given a topic
// with no wildcards (publish topic), it returns a list of subscribers that subscribes
// to the topic. For each of the level names, it's a match
// - if there are subscribers to '#', then all the subscribers are added to result set
//
// NOTE(review): no special-casing of topics beginning with '$' here (the '$'
// branch in nextTopicLevel is commented out), so '#'/'+' filters will also
// match $SYS-style topics — confirm whether a caller enforces [MQTT-4.7.2-1].
// NOTE(review): a filter like "sport/#" stored under snodes["sport"]["#"] is
// not reached when the publish topic is exactly "sport" (recursion stops at
// the "sport" node) — per [MQTT-4.7.1-2] '#' should include the parent level;
// verify this is acceptable to callers.
func (this *snode) smatch(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
	// If the topic is empty, it means we are at the final matching snode. If so,
	// let's find the subscribers that match the qos and append them to the list.
	if len(topic) == 0 {
		this.matchQos(qos, subs, qoss)
		return nil
	}
	// ntl = next topic level
	ntl, rem, err := nextTopicLevel(topic)
	if err != nil {
		return err
	}
	level := string(ntl)
	for k, n := range this.snodes {
		// If the key is "#", then these subscribers are added to the result set
		if k == MWC {
			n.matchQos(qos, subs, qoss)
		} else if k == SWC || k == level {
			if err := n.smatch(rem, qos, subs, qoss); err != nil {
				return err
			}
		}
	}
	return nil
}
// rnode is one level of the retained-message trie.
type rnode struct {
	// Retained message whose (wildcard-free) topic terminates at this level,
	// or nil if none.
	msg *packets.PublishPacket
	// Child nodes, keyed by the next topic level.
	rnodes map[string]*rnode
}
// newRNode allocates an empty retained-message trie node with an initialized
// child map.
func newRNode() *rnode {
	n := new(rnode)
	n.rnodes = make(map[string]*rnode)
	return n
}
// rinsert stores msg as the retained message for the remaining topic levels,
// creating intermediate trie nodes as needed. A new retained message always
// replaces any previously retained one for the same topic.
func (this *rnode) rinsert(topic []byte, msg *packets.PublishPacket) error {
	// If there's no more topic levels, that means we are at the matching rnode.
	if len(topic) == 0 {
		// Always replace: a later retained PUBLISH supersedes the earlier
		// one. (The old "if this.msg == nil" guard silently dropped every
		// update after the first, so retained messages could never change.)
		this.msg = msg
		return nil
	}
	// Not the last level, so let's find or create the next level rnode, and
	// recursively call its insert().
	// ntl = next topic level
	ntl, rem, err := nextTopicLevel(topic)
	if err != nil {
		return err
	}
	level := string(ntl)
	// Add rnode if it doesn't already exist
	n, ok := this.rnodes[level]
	if !ok {
		n = newRNode()
		this.rnodes[level] = n
	}
	return n.rinsert(rem, msg)
}
// rremove deletes the retained message stored for the supplied (wildcard-free)
// topic, pruning any trie branches that become completely empty on the way
// back up. Returns an error if no node exists for the topic.
func (this *rnode) rremove(topic []byte) error {
	// If the topic is empty, it means we are at the final matching rnode. If so,
	// let's remove the message.
	if len(topic) == 0 {
		this.msg = nil
		return nil
	}
	// Not the last level, so let's find the next level rnode, and recursively
	// call its remove().
	// ntl = next topic level
	ntl, rem, err := nextTopicLevel(topic)
	if err != nil {
		return err
	}
	level := string(ntl)
	// Find the rnode that matches the topic level
	n, ok := this.rnodes[level]
	if !ok {
		return fmt.Errorf("No topic found")
	}
	// Remove the retained message from the next level rnode
	if err := n.rremove(rem); err != nil {
		return err
	}
	// Prune the child only if it is completely empty: no children AND no
	// retained message of its own. (The old check ignored n.msg, so removing
	// "a/b" would also drop a message retained at "a".)
	if len(n.rnodes) == 0 && n.msg == nil {
		delete(this.rnodes, level)
	}
	return nil
}
// rmatch() finds the retained messages for the topic and qos provided. It's somewhat
// of a reverse match compare to match() since the supplied topic can contain
// wildcards, whereas the retained message topic is a full (no wildcard) topic.
func (this *rnode) rmatch(topic []byte, msgs *[]*packets.PublishPacket) error {
	// If the topic is empty, it means we are at the final matching rnode. If so,
	// add the retained msg to the list.
	if len(topic) == 0 {
		if this.msg != nil {
			*msgs = append(*msgs, this.msg)
		}
		return nil
	}
	// ntl = next topic level
	ntl, rem, err := nextTopicLevel(topic)
	if err != nil {
		return err
	}
	level := string(ntl)
	if level == MWC {
		// If '#', add all retained messages starting this node.
		// Note this includes this node's own message, so "sport/#" also
		// returns the message retained at "sport" itself.
		this.allRetained(msgs)
	} else if level == SWC {
		// If '+', check all nodes at this level. Next levels must be matched.
		for _, n := range this.rnodes {
			if err := n.rmatch(rem, msgs); err != nil {
				return err
			}
		}
	} else {
		// Otherwise, find the matching node, go to the next level
		if n, ok := this.rnodes[level]; ok {
			if err := n.rmatch(rem, msgs); err != nil {
				return err
			}
		}
	}
	return nil
}
// allRetained appends this node's retained message (if any) and then walks
// every child, collecting the entire subtree's retained messages into msgs.
func (this *rnode) allRetained(msgs *[]*packets.PublishPacket) {
	if m := this.msg; m != nil {
		*msgs = append(*msgs, m)
	}
	for _, child := range this.rnodes {
		child.allRetained(msgs)
	}
}
// Character-class states used by nextTopicLevel's scanner.
// NOTE(review): stateSEP is never assigned and stateSYS is only referenced
// from commented-out code — both are kept for documentation of the grammar.
const (
	stateCHR byte = iota // Regular character
	stateMWC             // Multi-level wildcard
	stateSWC             // Single-level wildcard
	stateSEP             // Topic level separator
	stateSYS             // System level topic ($)
)
// Returns topic level, remaining topic levels and any errors.
// The scanner validates that '#' and '+' occupy an entire level and that '#'
// only appears at the last level.
//
// NOTE(review): a leading '/' (empty first level, i == 0) is returned as the
// single-level wildcard "+" — so "/foo" behaves like "+/foo". Also, a trailing
// '/' yields an empty remainder, so "a/" is treated the same as "a". Both are
// deliberate-looking quirks; confirm they match the intended topic semantics.
func nextTopicLevel(topic []byte) ([]byte, []byte, error) {
	s := stateCHR
	for i, c := range topic {
		switch c {
		case '/':
			// '#' must be the last level, so a separator after it is invalid.
			if s == stateMWC {
				return nil, nil, fmt.Errorf("Multi-level wildcard found in topic and it's not at the last level")
			}
			if i == 0 {
				return []byte(SWC), topic[i+1:], nil
			}
			return topic[:i], topic[i+1:], nil
		case '#':
			if i != 0 {
				return nil, nil, fmt.Errorf("Wildcard character '#' must occupy entire topic level")
			}
			s = stateMWC
		case '+':
			if i != 0 {
				return nil, nil, fmt.Errorf("Wildcard character '+' must occupy entire topic level")
			}
			s = stateSWC
		// case '$':
		// 	if i == 0 {
		// 		return nil, nil, fmt.Errorf("Cannot publish to $ topics")
		// 	}
		// 	s = stateSYS
		default:
			if s == stateMWC || s == stateSWC {
				return nil, nil, fmt.Errorf("Wildcard characters '#' and '+' must occupy entire topic level")
			}
			s = stateCHR
		}
	}
	// If we got here that means we didn't hit the separator along the way, so the
	// topic is either empty, or does not contain a separator. Either way, we return
	// the full topic
	return topic, nil, nil
}
// The QoS of the payload messages sent in response to a subscription must be the
// minimum of the QoS of the originally published message (in this case, it's the
// qos parameter) and the maximum QoS granted by the server (in this case, it's
// the QoS in the topic tree).
//
// matchQos appends every subscriber at this node to subs and the delivery QoS
// — min(published qos, granted qos) — to qoss at the same index. (The previous
// code appended the published qos unconditionally, which could deliver above
// the maximum QoS granted at subscribe time, violating [MQTT-3.8.4-6].)
func (this *snode) matchQos(qos byte, subs *[]interface{}, qoss *[]byte) {
	for i, sub := range this.subs {
		*subs = append(*subs, sub)
		// Cap the delivery QoS at the QoS granted to this subscriber.
		granted := this.qos[i]
		if granted > qos {
			granted = qos
		}
		*qoss = append(*qoss, granted)
	}
}
// equal reports whether two subscriber keys refer to the same subscriber.
// Values of different dynamic types are never equal. Function values are
// compared by identity of their code pointer (func values are not comparable
// with == in Go); all other values are compared with ==.
//
// Bug fixed: the old code returned `&k1 == &k2`, which compares the addresses
// of the two local interface parameters — always distinct — so function
// subscribers never compared equal and could never be updated or removed.
// The former type switch after `k1 == k2` was dead code (it re-compared
// values that == had already reported unequal) and has been removed.
func equal(k1, k2 interface{}) bool {
	if reflect.TypeOf(k1) != reflect.TypeOf(k2) {
		return false
	}
	if reflect.ValueOf(k1).Kind() == reflect.Func {
		// Identity comparison via the underlying code pointer. Note that
		// distinct closures sharing the same code compare equal here.
		return reflect.ValueOf(k1).Pointer() == reflect.ValueOf(k2).Pointer()
	}
	return k1 == k2
}

91
lib/topics/topics.go Normal file
View File

@@ -0,0 +1,91 @@
package topics
import (
"fmt"
"github.com/eclipse/paho.mqtt.golang/packets"
)
// Topic string special characters (MQTT v3.1.1, section 4.7).
const (
	// MWC is the multi-level wildcard
	MWC = "#"
	// SWC is the single level wildcard
	SWC = "+"
	// SEP is the topic level separator
	SEP = "/"
	// SYS is the starting character of the system level topics
	SYS = "$"
	// Both wildcards
	_WC = "#+"
)
// providers holds every registered TopicsProvider, keyed by name.
// Not safe for concurrent mutation; Register/Unregister are expected to be
// called during program initialization.
var (
	providers = make(map[string]TopicsProvider)
)
// TopicsProvider is the backend contract for topic handling: managing
// subscriptions (Subscribe/Unsubscribe/Subscribers) and retained messages
// (Retain/Retained), plus resource release (Close).
type TopicsProvider interface {
	Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error)
	Unsubscribe(topic []byte, subscriber interface{}) error
	Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error
	Retain(msg *packets.PublishPacket) error
	Retained(topic []byte, msgs *[]*packets.PublishPacket) error
	Close() error
}
// Register makes a TopicsProvider available under the given name. It panics
// if provider is nil or if Register is called twice with the same name.
func Register(name string, provider TopicsProvider) {
	if provider == nil {
		// Message fixed: was "Register provide is nil".
		panic("topics: Register provider is nil")
	}
	if _, dup := providers[name]; dup {
		panic("topics: Register called twice for provider " + name)
	}
	providers[name] = provider
}
// Unregister removes the named provider from the registry. Unknown names are
// a no-op.
func Unregister(name string) {
	delete(providers, name)
}
// Manager wraps a single TopicsProvider and delegates every call to it.
type Manager struct {
	// p is the underlying provider chosen at construction time.
	p TopicsProvider
}
// NewManager returns a Manager backed by the provider registered under
// providerName, or an error if no such provider was registered.
func NewManager(providerName string) (*Manager, error) {
	p, ok := providers[providerName]
	if !ok {
		// Error prefix fixed to match this package: the old message said
		// "session:", copy-pasted from the session package.
		return nil, fmt.Errorf("topics: unknown provider %q", providerName)
	}
	return &Manager{p: p}, nil
}
// Subscribe delegates to the underlying provider's Subscribe.
func (this *Manager) Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error) {
	return this.p.Subscribe(topic, qos, subscriber)
}
// Unsubscribe delegates to the underlying provider's Unsubscribe.
func (this *Manager) Unsubscribe(topic []byte, subscriber interface{}) error {
	return this.p.Unsubscribe(topic, subscriber)
}
// Subscribers delegates to the underlying provider's Subscribers.
func (this *Manager) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
	return this.p.Subscribers(topic, qos, subs, qoss)
}
// Retain delegates to the underlying provider's Retain.
func (this *Manager) Retain(msg *packets.PublishPacket) error {
	return this.p.Retain(msg)
}
// Retained delegates to the underlying provider's Retained.
func (this *Manager) Retained(topic []byte, msgs *[]*packets.PublishPacket) error {
	return this.p.Retained(topic, msgs)
}
// Close delegates to the underlying provider's Close.
func (this *Manager) Close() error {
	return this.p.Close()
}

View File

@@ -6,7 +6,6 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/zap" "go.uber.org/zap"
) )