67 Commits
packet ... 1.3

Author SHA1 Message Date
joy.zhou
221d00480e update read.me 2018-01-26 16:29:14 +08:00
zhouyuyan
91733bf91e modify debug log 2018-01-26 15:47:34 +08:00
Marc Magnin
ef252550dc fhmq/hmq#5 added zap logger (#11) 2018-01-26 13:51:36 +08:00
joy.zhou
1058256235 update readme 2018-01-25 19:34:37 +08:00
joy.zhou
5a569f14a3 del debug info
delete debug message body
2018-01-25 19:31:47 +08:00
zhouyuyan
93b21777ff add lisence 2018-01-25 13:47:50 +08:00
zhouyuyan
dcf2934e1b add flag for hmq 2018-01-25 13:11:45 +08:00
joy.zhou
d9e6e216b0 Merge pull request #4 from MarcMagnin/master
fhmq/hmq#2 added full package ref
2018-01-24 18:14:13 +08:00
Marc Magnin
ca3951769a fhmq/hmq#2 added full package ref 2018-01-23 15:29:16 +01:00
zhouyuyan
0439e7ce90 fxi ws conn 2018-01-22 09:30:08 +08:00
zhouyuyan
dc0f2185ab skip self 2018-01-19 13:53:47 +08:00
zhouyuyan
7462afcfb5 modify readme 2018-01-19 13:49:53 +08:00
zhouyuyan
114e6f901e modify cluster 2018-01-19 13:41:17 +08:00
zhouyuyan
0cb51bd37a Merge branch 'master' of https://github.com/fhmq/hmq 2018-01-18 09:18:38 +08:00
zhouyuyan
819b4725f2 modify route 2018-01-18 09:17:48 +08:00
joy.zhou
85bdeccbfc release link
addd down link
2018-01-17 21:39:31 +08:00
zhouyuyan
1339a04b28 modify Dockerfile 2018-01-17 10:11:36 +08:00
zhouyuyan
957329d85c modify Dockerfile 2018-01-17 10:10:04 +08:00
zhouyuyan
7db7edaa17 cluster fix 2018-01-17 09:39:07 +08:00
zhouyuyan
1d6f6a4a71 add cluster 2018-01-16 16:50:10 +08:00
zhouyuyan
123bb7210f move dispatcher 2018-01-02 10:55:28 +08:00
zhouyuyan
9ad6590e83 modify timer 2017-12-28 09:13:20 +08:00
zhouyuyan
516db49db5 modify keep alive 2017-12-27 16:42:38 +08:00
zhouyuyan
a260057bfe modify time close 2017-12-08 13:25:05 +08:00
zhouyuyan
bdd802ebfb modify log 2017-12-07 16:30:48 +08:00
zhouyuyan
5786e69b01 modify cluster logic 2017-11-21 14:05:06 +08:00
zhouyuyan
6a89b627d4 add clientID in log for debug 2017-11-02 15:31:57 +08:00
zhouyuyan
208a7cf0a8 wait for message when close connection 2017-10-26 16:11:01 +08:00
zhouyuyan
a7fb7f1912 modify close old connect connection logic 2017-10-26 15:57:19 +08:00
zhouyuyan
eeab0c6b7d modify readloop 2017-09-22 12:08:11 +08:00
zhouyuyan
4646042b7f disconnect 2017-09-12 15:37:05 +08:00
zhouyuyan
49385e52fd log format 2017-09-12 09:50:54 +08:00
zhouyuyan
3ed8625bb9 log formar 2017-09-12 09:49:50 +08:00
zhouyuyan
6b50060eae free memory 2017-09-12 09:17:36 +08:00
zhouyuyan
96277996f0 packet 2017-09-12 09:15:08 +08:00
zhouyuyan
5601632a33 packet 2017-09-11 17:00:02 +08:00
zhouyuyan
cc1b3239ad remote logix fixed 2017-09-11 11:10:06 +08:00
zhouyuyan
476d22568b del statemonitor 2017-09-11 09:29:52 +08:00
chowyu08
c85ba76f8f 'status' 2017-09-09 16:04:40 +08:00
chowyu08
f3b2924b07 'addwss' 2017-09-09 15:37:00 +08:00
zhouyuyan
6144aeb6bf modifydockerfile 2017-09-08 16:40:13 +08:00
zhouyuyan
83b6934621 fix ws bug 2017-09-08 16:22:27 +08:00
zhouyuyan
7d6fcb7d65 fix ws buf 2017-09-08 16:04:39 +08:00
zhouyuyan
35390baa92 lb 2017-09-08 12:22:51 +08:00
zhouyuyan
f2efaa9992 lb 2017-09-08 10:44:53 +08:00
zhouyuyan
258912b33a lb 2017-09-08 10:44:01 +08:00
zhouyuyan
7f45bd6bc9 ls 2017-09-08 09:59:01 +08:00
chowyu08
7073e9b4ba 'queue' 2017-09-07 22:31:50 +08:00
zhouyuyan
8c98346546 dd 2017-09-07 17:03:33 +08:00
zhouyuyan
4300a32f6b queue lb 2017-09-07 15:31:38 +08:00
zhouyuyan
ae1af54c6e cluster sub 2017-09-07 14:32:01 +08:00
zhouyuyan
34378164e0 fixbug 2017-09-07 13:31:16 +08:00
chowyu08
47c44570fc 'fixtlsbug' 2017-09-06 20:55:54 +08:00
zhouyuyan
417a12174c force handshark 2017-09-06 09:18:43 +08:00
zhouyuyan
43a6bb8c5d modify remote 2017-09-05 13:52:53 +08:00
zhouyuyan
18d18738be close 2017-09-05 10:26:44 +08:00
zhouyuyan
8af790ffba broker 2017-09-04 16:57:01 +08:00
zhouyuyan
5e937601ce keepalive 2017-09-04 14:01:26 +08:00
zhouyuyan
31548b10e5 addreturn 2017-09-04 09:49:05 +08:00
zhouyuyan
ca7ebfb6e3 cluster sub 2017-09-01 22:17:22 +08:00
zhouyuyan
b98ae9ec6f qos 1 2017-09-01 19:53:02 +08:00
zhouyuyan
8bf6ccaa25 syncmap 2017-09-01 19:29:33 +08:00
zhouyuyan
65ac09cf50 mqp 2017-09-01 16:59:57 +08:00
zhouyuyan
d37100d059 del message 2017-09-01 14:00:53 +08:00
zhouyuyan
50a9a6841d byte 2017-09-01 13:57:40 +08:00
zhouyuyan
a45cccaa7a packet 2017-09-01 11:08:28 +08:00
chowyu08
c732d395e1 'packet' 2017-08-31 22:04:00 +08:00
62 changed files with 1108 additions and 6344 deletions

1
.gitignore vendored
View File

@@ -1,3 +1,4 @@
hmq
log
log/*
*.test

View File

@@ -1,11 +1,11 @@
FROM alpine
COPY hmq /
COPY broker.config /
COPY tls /tls
COPY ssl /ssl
COPY conf /conf
EXPOSE 1883
EXPOSE 1888
EXPOSE 8883
EXPOSE 1993
CMD ["/hmq"]

2
lib/message/LICENSE → LICENSE Executable file → Normal file
View File

@@ -1,4 +1,4 @@
Apache License
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

View File

@@ -3,25 +3,69 @@ Free and High Performance MQTT Broker
## About
Golang MQTT Broker, Version 3.1.1, and Compatible
for [eclipse paho client](https://github.com/eclipse?utf8=%E2%9C%93&q=mqtt&type=&language=)
for [eclipse paho client](https://github.com/eclipse?utf8=%E2%9C%93&q=mqtt&type=&language=) and mosquitto-client
Download: [click here](https://github.com/fhmq/hmq/releases)
## RUNNING
```bash
$ git clone https://github.com/fhmq/hmq.git
$ cd hmq
$ go get github.com/fhmq/hmq
$ cd $GOPATH/github.com/fhmq/hmq
$ go run main.go
```
### broker.config
## Usage of hmq:
~~~
Usage of ./hmq:
-w int
worker num to process message, perfer (client num)/10. (default 1024)
-worker int
worker num to process message, perfer (client num)/10. (default 1024)
-h string
Network host to listen on. (default "0.0.0.0")
-host string
Network host to listen on. (default "0.0.0.0")
-p string
Port to listen on. (default "1883")
-port string
Port to listen on. (default "1883")
-c string
config file for hmq
-config string
config file for hmq
-cluster string
Cluster ip from which members can connect.
-cluster_listen string
Cluster ip from which members can connect.
-cluster_port string
Cluster port from which members can connect.
-cp string
Cluster port from which members can connect.
-r string
Router who maintenance cluster info
-router string
Router who maintenance cluster info
-ws_path string
path for ws to listen on
-ws_port string
port for ws to listen on
-wspath string
path for ws to listen on
-wsport string
port for ws to listen on
~~~
### hmq.config
~~~
{
"workerNum": 4096,
"port": "1883",
"host": "0.0.0.0",
"cluster": {
"host": "0.0.0.0",
"port": "1993",
"routers": ["10.10.0.11:1993","10.10.0.12:1993"]
"port": "1993"
},
"router": "127.0.0.1:9888",
"wsPort": "1888",
"wsPath": "/ws",
"wsTLS": true,
@@ -40,7 +84,7 @@ $ go run main.go
### Features and Future
* Supports QOS 0
* Supports QOS 0 and 1
* Cluster Support
@@ -58,6 +102,13 @@ $ go run main.go
* Flexible ACL
### Cluster
```bash
1, start router for hmq (https://github.com/fhmq/router.git)
2, config router in hmq.config ("router": "127.0.0.1:9888")
```
### QUEUE SUBSCRIBE
~~~
| Prefix | Examples |
@@ -128,4 +179,4 @@ Client -> | Rule1 | --nomatch--> | Rule2 | --nomatch--> | Rule3 | -->
## License
* Apache License Version 2.0
* Apache License Version 2.0

View File

@@ -1 +0,0 @@
theme: jekyll-theme-slate

View File

@@ -1,10 +1,13 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
import (
"hmq/lib/acl"
"strings"
log "github.com/cihub/seelog"
"github.com/fhmq/hmq/lib/acl"
"go.uber.org/zap"
"github.com/fsnotify/fsnotify"
)
@@ -14,7 +17,7 @@ const (
)
func (c *client) CheckTopicAuth(typ int, topic string) bool {
if !c.broker.config.Acl {
if c.typ != CLIENT || !c.broker.config.Acl {
return true
}
if strings.HasPrefix(topic, "$queue/") {
@@ -40,10 +43,10 @@ func (b *Broker) handleFsEvent(event fsnotify.Event) error {
case b.config.AclConf:
if event.Op&fsnotify.Write == fsnotify.Write ||
event.Op&fsnotify.Create == fsnotify.Create {
log.Info("text:handling acl config change event:", event)
log.Info("text:handling acl config change event:", zap.String("filename", event.Name))
aclconfig, err := acl.AclConfigLoad(event.Name)
if err != nil {
log.Error("aclconfig change failed, load acl conf error: ", err)
log.Error("aclconfig change failed, load acl conf error: ", zap.Error(err))
return err
}
b.AclConfig = aclconfig
@@ -56,14 +59,14 @@ func (b *Broker) StartAclWatcher() {
go func() {
wch, e := fsnotify.NewWatcher()
if e != nil {
log.Error("start monitor acl config file error,", e)
log.Error("start monitor acl config file error,", zap.Error(e))
return
}
defer wch.Close()
for _, i := range watchList {
if err := wch.Add(i); err != nil {
log.Error("start monitor acl config file error,", err)
log.Error("start monitor acl config file error,", zap.Error(err))
return
}
}
@@ -73,7 +76,7 @@ func (b *Broker) StartAclWatcher() {
case evt := <-wch.Events:
b.handleFsEvent(evt)
case err := <-wch.Errors:
log.Error("error:", err.Error())
log.Error("error:", zap.Error(err))
}
}
}()

View File

@@ -1,28 +1,42 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
import (
"crypto/tls"
"hmq/lib/acl"
"hmq/lib/message"
"net"
"net/http"
"runtime/debug"
"sync"
"sync/atomic"
"time"
"github.com/fhmq/hmq/lib/acl"
"github.com/eclipse/paho.mqtt.golang/packets"
"github.com/shirou/gopsutil/mem"
"go.uber.org/zap"
"golang.org/x/net/websocket"
log "github.com/cihub/seelog"
"github.com/fhmq/hmq/logger"
)
var (
log = logger.Get().Named("Broker")
)
type Broker struct {
id string
cid uint64
mu sync.Mutex
config *Config
tlsConfig *tls.Config
AclConfig *acl.ACLConfig
clients cMap
routes cMap
remotes cMap
clients sync.Map
routes sync.Map
remotes sync.Map
nodes map[string]interface{}
sl *Sublist
rl *RetainList
queues map[string]int
@@ -30,19 +44,17 @@ type Broker struct {
func NewBroker(config *Config) (*Broker, error) {
b := &Broker{
id: GenUniqueId(),
config: config,
sl: NewSublist(),
rl: NewRetainList(),
queues: make(map[string]int),
clients: NewClientMap(),
routes: NewClientMap(),
remotes: NewClientMap(),
id: GenUniqueId(),
config: config,
sl: NewSublist(),
rl: NewRetainList(),
nodes: make(map[string]interface{}),
queues: make(map[string]int),
}
if b.config.TlsPort != "" {
tlsconfig, err := NewTLSConfig(b.config.TlsInfo)
if err != nil {
log.Error("new tlsConfig error: ", err)
log.Error("new tlsConfig error", zap.Error(err))
return nil, err
}
b.tlsConfig = tlsconfig
@@ -50,7 +62,7 @@ func NewBroker(config *Config) (*Broker, error) {
if b.config.Acl {
aclconfig, err := acl.AclConfigLoad(b.config.AclConf)
if err != nil {
log.Error("Load acl conf error: ", err)
log.Error("Load acl conf error", zap.Error(err))
return nil, err
}
b.AclConfig = aclconfig
@@ -64,26 +76,62 @@ func (b *Broker) Start() {
log.Error("broker is null")
return
}
StartDispatcher()
//listen clinet over tcp
if b.config.Port != "" {
go b.StartListening(CLIENT)
go b.StartClientListening(false)
}
//listen for cluster
if b.config.Cluster.Port != "" {
go b.StartListening(ROUTER)
go b.StartClusterListening()
}
//listen for websocket
if b.config.WsPort != "" {
go b.StartWebsocketListening()
}
//listen client over tls
if b.config.TlsPort != "" {
go b.StartTLSListening()
go b.StartClientListening(true)
}
//connect on other node in cluster
if b.config.Router != "" {
b.ConnectToDiscovery()
}
//system monitor
go StateMonitor()
}
func StateMonitor() {
v, _ := mem.VirtualMemory()
timeSticker := time.NewTicker(time.Second * 30)
for {
select {
case <-timeSticker.C:
if v.UsedPercent > 75 {
debug.FreeOSMemory()
}
}
}
}
func (b *Broker) StartWebsocketListening() {
path := b.config.WsPath
hp := ":" + b.config.WsPort
log.Info("Start Webscoker Listening on ", hp, path)
log.Info("Start Websocket Listening on ", zap.String("hp", hp), zap.String("path", path))
http.Handle(path, websocket.Handler(b.wsHandler))
err := http.ListenAndServe(hp, nil)
var err error
if b.config.WsTLS {
err = http.ListenAndServeTLS(hp, b.config.TlsInfo.CertFile, b.config.TlsInfo.KeyFile, nil)
} else {
err = http.ListenAndServe(hp, nil)
}
if err != nil {
log.Error("ListenAndServe: " + err.Error())
return
@@ -91,17 +139,27 @@ func (b *Broker) StartWebsocketListening() {
}
func (b *Broker) wsHandler(ws *websocket.Conn) {
// io.Copy(ws, ws)
atomic.AddUint64(&b.cid, 1)
go b.handleConnection(CLIENT, ws, b.cid)
ws.PayloadType = websocket.BinaryFrame
b.handleConnection(CLIENT, ws, b.cid)
}
func (b *Broker) StartTLSListening() {
hp := b.config.TlsHost + ":" + b.config.TlsPort
log.Info("Start TLS Listening client on ", hp)
l, e := tls.Listen("tcp", hp, b.tlsConfig)
if e != nil {
log.Error("Error listening on ", e)
func (b *Broker) StartClientListening(Tls bool) {
var hp string
var err error
var l net.Listener
if Tls {
hp = b.config.TlsHost + ":" + b.config.TlsPort
l, err = tls.Listen("tcp", hp, b.tlsConfig)
log.Info("Start TLS Listening client on ", zap.String("hp", hp))
} else {
hp := b.config.Host + ":" + b.config.Port
l, err = net.Listen("tcp", hp)
log.Info("Start Listening client on ", zap.String("hp", hp))
}
if err != nil {
log.Error("Error listening on ", zap.Error(err))
return
}
tmpDelay := 10 * ACCEPT_MIN_SLEEP
@@ -110,99 +168,127 @@ func (b *Broker) StartTLSListening() {
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Temporary() {
log.Error("Temporary Client Accept Error(%v), sleeping %dms",
ne, tmpDelay/time.Millisecond)
zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond))
time.Sleep(tmpDelay)
tmpDelay *= 2
if tmpDelay > ACCEPT_MAX_SLEEP {
tmpDelay = ACCEPT_MAX_SLEEP
}
} else {
log.Error("Accept error: %v", err)
log.Error("Accept error: %v", zap.Error(err))
}
continue
}
tmpDelay = ACCEPT_MIN_SLEEP
atomic.AddUint64(&b.cid, 1)
go b.handleConnection(CLIENT, conn, b.cid)
}
}
func (b *Broker) StartListening(typ int) {
var hp string
if typ == CLIENT {
hp = b.config.Host + ":" + b.config.Port
log.Info("Start Listening client on ", hp)
} else if typ == ROUTER {
hp = b.config.Cluster.Host + ":" + b.config.Cluster.Port
log.Info("Start Listening cluster on ", hp)
func (b *Broker) Handshake(conn net.Conn) bool {
nc := tls.Server(conn, b.tlsConfig)
time.AfterFunc(DEFAULT_TLS_TIMEOUT, func() { TlsTimeout(nc) })
nc.SetReadDeadline(time.Now().Add(DEFAULT_TLS_TIMEOUT))
// Force handshake
if err := nc.Handshake(); err != nil {
log.Error("TLS handshake error, ", zap.Error(err))
return false
}
nc.SetReadDeadline(time.Time{})
return true
}
func TlsTimeout(conn *tls.Conn) {
nc := conn
// Check if already closed
if nc == nil {
return
}
cs := nc.ConnectionState()
if !cs.HandshakeComplete {
log.Error("TLS handshake timeout")
nc.Close()
}
}
func (b *Broker) StartClusterListening() {
var hp string = b.config.Cluster.Host + ":" + b.config.Cluster.Port
log.Info("Start Listening cluster on ", zap.String("hp", hp))
l, e := net.Listen("tcp", hp)
if e != nil {
log.Error("Error listening on ", e)
log.Error("Error listening on ", zap.Error(e))
return
}
var idx uint64 = 0
tmpDelay := 10 * ACCEPT_MIN_SLEEP
for {
conn, err := l.Accept()
if err != nil {
if ne, ok := err.(net.Error); ok && ne.Temporary() {
log.Error("Temporary Client Accept Error(%v), sleeping %dms",
ne, tmpDelay/time.Millisecond)
zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond))
time.Sleep(tmpDelay)
tmpDelay *= 2
if tmpDelay > ACCEPT_MAX_SLEEP {
tmpDelay = ACCEPT_MAX_SLEEP
}
} else {
log.Error("Accept error: %v", err)
log.Error("Accept error: %v", zap.Error(err))
}
continue
}
tmpDelay = ACCEPT_MIN_SLEEP
atomic.AddUint64(&b.cid, 1)
go b.handleConnection(typ, conn, b.cid)
go b.handleConnection(ROUTER, conn, idx)
}
}
func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
//process connect packet
buf, err := ReadPacket(conn)
packet, err := packets.ReadPacket(conn)
if err != nil {
log.Error("read connect packet error: ", err)
log.Error("read connect packet error: ", zap.Error(err))
return
}
connMsg, err := DecodeConnectMessage(buf)
if packet == nil {
log.Error("received nil packet")
return
}
msg, ok := packet.(*packets.ConnectPacket)
if !ok {
log.Error("received msg that was not Connect")
return
}
connack := packets.NewControlPacket(packets.Connack).(*packets.ConnackPacket)
connack.ReturnCode = packets.Accepted
connack.SessionPresent = msg.CleanSession
err = connack.Write(conn)
if err != nil {
log.Error(err)
log.Error("send connack error, ", zap.Error(err), zap.String("clientID", msg.ClientIdentifier))
return
}
connack := message.NewConnackMessage()
connack.SetReturnCode(message.ConnectionAccepted)
ack, _ := EncodeMessage(connack)
err1 := WriteBuffer(conn, ack)
if err1 != nil {
log.Error("send connack error, ", err1)
return
}
willmsg := message.NewPublishMessage()
if connMsg.WillFlag() {
willmsg.SetQoS(connMsg.WillQos())
willmsg.SetPayload(connMsg.WillMessage())
willmsg.SetRetain(connMsg.WillRetain())
willmsg.SetTopic(connMsg.WillTopic())
willmsg.SetDup(false)
willmsg := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
if msg.WillFlag {
willmsg.Qos = msg.WillQos
willmsg.TopicName = msg.WillTopic
willmsg.Retain = msg.WillRetain
willmsg.Payload = msg.WillMessage
willmsg.Dup = msg.Dup
} else {
willmsg = nil
}
info := info{
clientID: connMsg.ClientId(),
username: connMsg.Username(),
password: connMsg.Password(),
keepalive: connMsg.KeepAlive(),
clientID: msg.ClientIdentifier,
username: msg.Username,
password: msg.Password,
keepalive: msg.Keepalive,
willMsg: willmsg,
}
@@ -212,113 +298,247 @@ func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
conn: conn,
info: info,
}
c.init()
cid := c.info.clientID
var msgPool *MessagePool
var exist bool
var old *client
cid := string(c.info.clientID)
if typ == CLIENT {
old, exist = b.clients.Update(cid, c)
var old interface{}
switch typ {
case CLIENT:
msgPool = MSGPool[idx%MessagePoolNum].GetPool()
} else if typ == ROUTER {
old, exist = b.routes.Update(cid, c)
msgPool = MSGPool[MessagePoolNum].GetPool()
}
if exist {
log.Warn("client or routers exists, close old...")
old.Close()
}
c.readLoop(msgPool)
}
func (b *Broker) ConnectToRouters() {
for i := 0; i < len(b.config.Cluster.Routes); i++ {
url := b.config.Cluster.Routes[i]
go b.connectRouter(url, "")
}
}
func (b *Broker) connectRouter(url, remoteID string) {
for {
conn, err := net.Dial("tcp", url)
if err != nil {
log.Error("Error trying to connect to route: ", err)
select {
case <-time.After(DEFAULT_ROUTE_CONNECT):
log.Debug("Connect to route timeout ,retry...")
continue
c.mp = msgPool
old, exist = b.clients.Load(cid)
if exist {
log.Warn("client exist, close old...", zap.String("clientID", c.info.clientID))
ol, ok := old.(*client)
if ok {
msg := &Message{client: c, packet: DisconnectdPacket}
ol.mp.queue <- msg
}
}
route := &route{
remoteID: remoteID,
remoteUrl: url,
b.clients.Store(cid, c)
case ROUTER:
msgPool = MSGPool[(MessagePoolNum + idx)].GetPool()
c.mp = msgPool
old, exist = b.routes.Load(cid)
if exist {
log.Warn("router exist, close old...")
ol, ok := old.(*client)
if ok {
msg := &Message{client: c, packet: DisconnectdPacket}
ol.mp.queue <- msg
}
}
cid := GenUniqueId()
info := info{
clientID: []byte(cid),
}
c := &client{
typ: REMOTE,
conn: conn,
route: route,
info: info,
}
b.remotes.Set(cid, c)
c.SendConnect()
c.SendInfo()
// s.createRemote(conn, route)
msgPool := MSGPool[(MessagePoolNum + 1)].GetPool()
c.readLoop(msgPool)
b.routes.Store(cid, c)
}
c.readLoop()
}
func (b *Broker) ConnectToDiscovery() {
var conn net.Conn
var err error
var tempDelay time.Duration = 0
for {
conn, err = net.Dial("tcp", b.config.Router)
if err != nil {
log.Error("Error trying to connect to route: ", zap.Error(err))
log.Debug("Connect to route timeout ,retry...")
if 0 == tempDelay {
tempDelay = 1 * time.Second
} else {
tempDelay *= 2
}
if max := 20 * time.Second; tempDelay > max {
tempDelay = max
}
time.Sleep(tempDelay)
continue
}
break
}
log.Debug("connect to router success :", zap.String("Router", b.config.Router))
cid := b.id
info := info{
clientID: cid,
keepalive: 60,
}
c := &client{
typ: CLUSTER,
broker: b,
conn: conn,
info: info,
}
c.init()
c.SendConnect()
c.SendInfo()
c.mp = &MSGPool[(MessagePoolNum + 2)]
go c.readLoop()
go c.StartPing()
}
func (b *Broker) connectRouter(id, addr string) {
var conn net.Conn
var err error
var timeDelay time.Duration = 0
retryTimes := 0
max := 32 * time.Second
for {
if !b.checkNodeExist(id, addr) {
return
}
conn, err = net.Dial("tcp", addr)
if err != nil {
log.Error("Error trying to connect to route: ", zap.Error(err))
if retryTimes > 50 {
return
}
log.Debug("Connect to route timeout ,retry...")
if 0 == timeDelay {
timeDelay = 1 * time.Second
} else {
timeDelay *= 2
}
if timeDelay > max {
timeDelay = max
}
time.Sleep(timeDelay)
retryTimes++
continue
}
break
}
route := route{
remoteID: id,
remoteUrl: addr,
}
cid := GenUniqueId()
info := info{
clientID: cid,
keepalive: 60,
}
c := &client{
broker: b,
typ: REMOTE,
conn: conn,
route: route,
info: info,
}
c.init()
b.remotes.Store(cid, c)
c.mp = MSGPool[(MessagePoolNum + 1)].GetPool()
c.SendConnect()
// c.SendInfo()
go c.readLoop()
go c.StartPing()
}
func (b *Broker) checkNodeExist(id, url string) bool {
if id == b.id {
return false
}
for k, v := range b.nodes {
if k == id {
return true
}
//skip
l, ok := v.(string)
if ok {
if url == l {
return true
}
}
}
return false
}
func (b *Broker) CheckRemoteExist(remoteID, url string) bool {
exist := false
remotes := b.remotes.Items()
for _, v := range remotes {
if v.route.remoteUrl == url {
// if v.route.remoteID == "" || v.route.remoteID != remoteID {
v.route.remoteID = remoteID
// }
exist = true
break
b.remotes.Range(func(key, value interface{}) bool {
v, ok := value.(*client)
if ok {
if v.route.remoteUrl == url {
v.route.remoteID = remoteID
exist = true
return false
}
}
}
return true
})
return exist
}
func (b *Broker) SendLocalSubsToRouter(c *client) {
clients := b.clients.Items()
subMsg := message.NewSubscribeMessage()
for _, client := range clients {
subs := client.subs
for _, sub := range subs {
subMsg.AddTopic(sub.topic, sub.qos)
subInfo := packets.NewControlPacket(packets.Subscribe).(*packets.SubscribePacket)
b.clients.Range(func(key, value interface{}) bool {
client, ok := value.(*client)
if ok {
subs := client.subs
for _, sub := range subs {
subInfo.Topics = append(subInfo.Topics, string(sub.topic))
subInfo.Qoss = append(subInfo.Qoss, sub.qos)
}
}
return true
})
if len(subInfo.Topics) > 0 {
err := c.WriterPacket(subInfo)
if err != nil {
log.Error("Send localsubs To Router error :", zap.Error(err))
}
}
err := c.writeMessage(subMsg)
if err != nil {
log.Error("Send localsubs To Router error :", err)
}
}
func (b *Broker) BroadcastInfoMessage(remoteID string, msg message.Message) {
remotes := b.remotes.Items()
for _, r := range remotes {
if r.route.remoteID == remoteID {
continue
func (b *Broker) BroadcastInfoMessage(remoteID string, msg *packets.PublishPacket) {
b.routes.Range(func(key, value interface{}) bool {
r, ok := value.(*client)
if ok {
if r.route.remoteID == remoteID {
return true
}
r.WriterPacket(msg)
}
r.writeMessage(msg)
}
return true
})
// log.Info("BroadcastInfoMessage success ")
}
func (b *Broker) BroadcastSubOrUnsubMessage(buf []byte) {
remotes := b.remotes.Items()
for _, r := range remotes {
r.writeBuffer(buf)
}
func (b *Broker) BroadcastSubOrUnsubMessage(packet packets.ControlPacket) {
b.routes.Range(func(key, value interface{}) bool {
r, ok := value.(*client)
if ok {
r.WriterPacket(packet)
}
return true
})
// log.Info("BroadcastSubscribeMessage remotes: ", s.remotes)
}
@@ -327,48 +547,40 @@ func (b *Broker) removeClient(c *client) {
typ := c.typ
switch typ {
case CLIENT:
b.clients.Remove(clientId)
b.clients.Delete(clientId)
case ROUTER:
b.routes.Remove(clientId)
b.routes.Delete(clientId)
case REMOTE:
b.remotes.Remove(clientId)
b.remotes.Delete(clientId)
}
// log.Info("delete client ,", clientId)
}
func (b *Broker) ProcessPublishMessage(msg *message.PublishMessage) {
if b == nil {
return
}
topic := string(msg.Topic())
func (b *Broker) PublishMessage(packet *packets.PublishPacket) {
topic := packet.TopicName
r := b.sl.Match(topic)
// log.Info("psubs num: ", len(r.psubs))
if len(r.qsubs) == 0 && len(r.psubs) == 0 {
if len(r.psubs) == 0 {
return
}
for _, sub := range r.psubs {
if sub != nil {
err := sub.client.writeMessage(msg)
err := sub.client.WriterPacket(packet)
if err != nil {
log.Error("process message for psub error, ", err)
log.Error("process message for psub error, ", zap.Error(err))
}
}
}
for i, sub := range r.qsubs {
// s.qmu.Lock()
if cnt, exist := b.queues[string(sub.topic)]; exist && i == cnt {
if sub != nil {
err := sub.client.writeMessage(msg)
if err != nil {
log.Error("process will message for qsub error, ", err)
}
}
b.queues[topic] = (b.queues[topic] + 1) % len(r.qsubs)
break
}
// s.qmu.Unlock()
}
}
func (b *Broker) BroadcastUnSubscribe(subs map[string]*subscription) {
unsub := packets.NewControlPacket(packets.Unsubscribe).(*packets.UnsubscribePacket)
for topic, _ := range subs {
unsub.Topics = append(unsub.Topics, topic)
}
if len(unsub.Topics) > 0 {
b.BroadcastSubOrUnsubMessage(unsub)
}
}

View File

@@ -1,24 +1,31 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
import (
"errors"
"hmq/lib/message"
"net"
"strings"
"sync"
"time"
log "github.com/cihub/seelog"
"github.com/eclipse/paho.mqtt.golang/packets"
"go.uber.org/zap"
)
const (
// special pub topic for cluster info BrokerInfoTopic
BrokerInfoTopic = "broker001info/brokerinfo"
BrokerInfoTopic = "broker000100101info"
// CLIENT is an end user.
CLIENT = 0
// ROUTER is another router in the cluster.
ROUTER = 1
//REMOTE is the router connect to other cluster
REMOTE = 2
REMOTE = 2
CLUSTER = 3
)
const (
Connected = 1
Disconnected = 2
)
type client struct {
@@ -27,23 +34,33 @@ type client struct {
broker *Broker
conn net.Conn
info info
route *route
route route
status int
closed chan int
smu sync.RWMutex
mp *MessagePool
subs map[string]*subscription
rsubs map[string]*subInfo
}
type subInfo struct {
sub *subscription
num int
}
type subscription struct {
client *client
topic []byte
topic string
qos byte
queue bool
}
type info struct {
clientID []byte
username []byte
clientID string
username string
password []byte
keepalive uint16
willMsg *message.PublishMessage
willMsg *packets.PublishPacket
localIP string
remoteIP string
}
@@ -53,125 +70,171 @@ type route struct {
remoteUrl string
}
var (
DisconnectdPacket = packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket)
)
func (c *client) init() {
c.smu.Lock()
defer c.smu.Unlock()
c.status = Connected
c.closed = make(chan int, 1)
c.rsubs = make(map[string]*subInfo)
c.subs = make(map[string]*subscription, 10)
c.info.localIP = strings.Split(c.conn.LocalAddr().String(), ":")[0]
c.info.remoteIP = strings.Split(c.conn.RemoteAddr().String(), ":")[0]
}
func (c *client) readLoop(msgPool *MessagePool) {
func (c *client) keepAlive(ch chan int) {
defer close(ch)
keepalive := time.Duration(c.info.keepalive*3/2) * time.Second
timer := time.NewTimer(keepalive)
msgPool := c.mp
for {
select {
case <-ch:
timer.Reset(keepalive)
case <-timer.C:
if c.typ == REMOTE || c.typ == CLUSTER {
timer.Reset(keepalive)
continue
}
log.Error("Client exceeded timeout, disconnecting. ", zap.String("ClientID", c.info.clientID), zap.Uint16("keepalive", c.info.keepalive))
msg := &Message{client: c, packet: DisconnectdPacket}
msgPool.queue <- msg
timer.Stop()
return
case _, ok := <-c.closed:
if !ok {
return
}
}
}
}
func (c *client) readLoop() {
nc := c.conn
msgPool := c.mp
if nc == nil || msgPool == nil {
return
}
msg := &Message{}
ch := make(chan int, 1000)
go c.keepAlive(ch)
for {
buf, err := ReadPacket(nc)
packet, err := packets.ReadPacket(nc)
if err != nil {
log.Error("read packet error: ", err)
c.Close()
return
log.Error("read packet error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
break
}
ch <- 1
msg := &Message{
client: c,
packet: packet,
}
msg.client = c
msg.msg = buf
msgPool.queue <- msg
}
msg := &Message{client: c, packet: DisconnectdPacket}
msgPool.queue <- msg
msgPool.Reduce()
}
func ProcessMessage(msg *Message) {
buf := msg.msg
c := msg.client
if c == nil || buf == nil {
ca := msg.packet
if ca == nil {
return
}
msgType := uint8(buf[0] & 0xF0 >> 4)
switch msgType {
case CONNACK:
// log.Info("Recv conack message..........")
c.ProcessConnAck(buf)
case CONNECT:
// log.Info("Recv connect message..........")
c.ProcessConnect(buf)
case PUBLISH:
// log.Info("Recv publish message..........")
c.ProcessPublish(buf)
case PUBACK:
//log.Info("Recv publish ack message..........")
c.ProcessPubAck(buf)
case PUBCOMP:
//log.Info("Recv publish ack message..........")
c.ProcessPubComp(buf)
case PUBREC:
//log.Info("Recv publish rec message..........")
c.ProcessPubREC(buf)
case PUBREL:
//log.Info("Recv publish rel message..........")
c.ProcessPubREL(buf)
case SUBSCRIBE:
// log.Info("Recv subscribe message.....")
c.ProcessSubscribe(buf)
case SUBACK:
// log.Info("Recv suback message.....")
case UNSUBSCRIBE:
// log.Info("Recv unsubscribe message.....")
c.ProcessUnSubscribe(buf)
case UNSUBACK:
//log.Info("Recv unsuback message.....")
case PINGREQ:
// log.Info("Recv PINGREQ message..........")
c.ProcessPing(buf)
case PINGRESP:
//log.Info("Recv PINGRESP message..........")
case DISCONNECT:
// log.Info("Recv DISCONNECT message.......")
log.Debug("Recv message from client,", zap.String("ClientID", c.info.clientID))
switch ca.(type) {
case *packets.ConnackPacket:
case *packets.ConnectPacket:
case *packets.PublishPacket:
packet := ca.(*packets.PublishPacket)
c.ProcessPublish(packet)
case *packets.PubackPacket:
case *packets.PubrecPacket:
case *packets.PubrelPacket:
case *packets.PubcompPacket:
case *packets.SubscribePacket:
packet := ca.(*packets.SubscribePacket)
c.ProcessSubscribe(packet)
case *packets.SubackPacket:
case *packets.UnsubscribePacket:
packet := ca.(*packets.UnsubscribePacket)
c.ProcessUnSubscribe(packet)
case *packets.UnsubackPacket:
case *packets.PingreqPacket:
c.ProcessPing()
case *packets.PingrespPacket:
case *packets.DisconnectPacket:
c.Close()
default:
log.Info("Recv Unknow message.......")
log.Info("Recv Unknow message.......", zap.String("ClientID", c.info.clientID))
}
}
func (c *client) ProcessConnect(buf []byte) {
}
func (c *client) ProcessConnAck(buf []byte) {
}
func (c *client) ProcessPublish(buf []byte) {
msg, err := DecodePublishMessage(buf)
if err != nil {
log.Error("Decode Publish Message error: ", err)
c.Close()
func (c *client) ProcessPublish(packet *packets.PublishPacket) {
if c.status == Disconnected {
return
}
topic := msg.Topic()
if c.typ != CLIENT || !c.CheckTopicAuth(PUB, string(topic)) {
topic := packet.TopicName
if topic == BrokerInfoTopic && c.typ == CLUSTER {
c.ProcessInfo(packet)
return
}
c.ProcessPublishMessage(buf, msg)
if msg.Retain() {
if !c.CheckTopicAuth(PUB, topic) {
log.Error("Pub Topics Auth failed, ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID))
return
}
switch packet.Qos {
case QosAtMostOnce:
c.ProcessPublishMessage(packet)
case QosAtLeastOnce:
puback := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
puback.MessageID = packet.MessageID
if err := c.WriterPacket(puback); err != nil {
log.Error("send puback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
return
}
c.ProcessPublishMessage(packet)
case QosExactlyOnce:
return
default:
log.Error("publish with unknown qos", zap.String("ClientID", c.info.clientID))
return
}
if packet.Retain {
if b := c.broker; b != nil {
err := b.rl.Insert(topic, buf)
err := b.rl.Insert(topic, packet)
if err != nil {
log.Error("Insert Retain Message error: ", err)
log.Error("Insert Retain Message error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
}
}
}
}
func (c *client) ProcessPublishMessage(buf []byte, msg *message.PublishMessage) {
func (c *client) ProcessPublishMessage(packet *packets.PublishPacket) {
if c.status == Disconnected {
return
}
b := c.broker
if b == nil {
return
}
typ := c.typ
topic := string(msg.Topic())
topic := packet.TopicName
r := b.sl.Match(topic)
// log.Info("psubs num: ", len(r.psubs))
@@ -181,220 +244,262 @@ func (c *client) ProcessPublishMessage(buf []byte, msg *message.PublishMessage)
for _, sub := range r.psubs {
if sub.client.typ == ROUTER {
if typ == ROUTER {
if typ != CLIENT {
continue
}
}
if sub != nil {
err := sub.client.writeBuffer(buf)
err := sub.client.WriterPacket(packet)
if err != nil {
log.Error("process message for psub error, ", err)
log.Error("process message for psub error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
}
}
}
for i, sub := range r.qsubs {
if sub.client.typ == ROUTER {
if typ == ROUTER {
continue
}
}
// s.qmu.Lock()
if cnt, exist := b.queues[string(sub.topic)]; exist && i == cnt {
if sub != nil {
err := sub.client.writeBuffer(buf)
if err != nil {
log.Error("process will message for qsub error, ", err)
pre := -1
now := -1
t := "$queue/" + topic
cnt, exist := b.queues[t]
if exist {
// log.Info("queue index : ", cnt)
for _, sub := range r.qsubs {
if sub.client.typ == ROUTER {
if typ != CLIENT {
continue
}
}
b.queues[topic] = (b.queues[topic] + 1) % len(r.qsubs)
break
if c.typ == CLIENT {
now = now + 1
} else {
now = now + sub.client.rsubs[t].num
}
if cnt > pre && cnt <= now {
if sub != nil {
err := sub.client.WriterPacket(packet)
if err != nil {
log.Error("send publish error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
}
}
break
}
pre = now
}
// s.qmu.Unlock()
}
length := getQueueSubscribeNum(r.qsubs)
if length > 0 {
b.queues[t] = (b.queues[t] + 1) % length
}
}
func (c *client) ProcessPubAck(buf []byte) {
func getQueueSubscribeNum(qsubs []*subscription) int {
topic := "$queue/"
if len(qsubs) < 1 {
return 0
} else {
topic = topic + qsubs[0].topic
}
num := 0
for _, sub := range qsubs {
if sub.client.typ == CLIENT {
num = num + 1
} else {
num = num + sub.client.rsubs[topic].num
}
}
return num
}
func (c *client) ProcessPubREC(buf []byte) {
func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) {
if c.status == Disconnected {
return
}
}
func (c *client) ProcessPubREL(buf []byte) {
}
func (c *client) ProcessPubComp(buf []byte) {
}
func (c *client) ProcessSubscribe(buf []byte) {
b := c.broker
if b == nil {
return
}
msg, err := DecodeSubscribeMessage(buf)
if err != nil {
log.Error("Decode Subscribe Message error: ", err)
c.Close()
return
}
topics := msg.Topics()
qos := msg.Qos()
topics := packet.Topics
qoss := packet.Qoss
suback := message.NewSubackMessage()
suback.SetPacketId(msg.PacketId())
suback := packets.NewControlPacket(packets.Suback).(*packets.SubackPacket)
suback.MessageID = packet.MessageID
var retcodes []byte
for i, t := range topics {
topic := string(t)
for i, topic := range topics {
t := topic
//check topic auth for client
if c.typ == CLIENT {
if !c.CheckTopicAuth(SUB, topic) {
log.Error("CheckSubAuth failed")
retcodes = append(retcodes, message.QosFailure)
if !c.CheckTopicAuth(SUB, topic) {
log.Error("Sub topic Auth failed: ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID))
retcodes = append(retcodes, QosFailure)
continue
}
queue := strings.HasPrefix(topic, "$queue/")
if queue {
if len(t) > 7 {
t = t[7:]
if _, exists := b.queues[topic]; !exists {
b.queues[topic] = 0
}
} else {
retcodes = append(retcodes, QosFailure)
continue
}
}
if _, exist := c.subs[topic]; !exist {
queue := false
if strings.HasPrefix(topic, "$queue/") {
if len(t) > 7 {
t = t[7:]
queue = true
// b.qmu.Lock()
if _, exists := b.queues[topic]; !exists {
b.queues[topic] = 0
}
// b.qmu.Unlock()
} else {
retcodes = append(retcodes, message.QosFailure)
continue
}
}
sub := &subscription{
topic: t,
qos: qos[i],
client: c,
queue: queue,
}
c.mu.Lock()
c.subs[topic] = sub
c.mu.Unlock()
err := b.sl.Insert(sub)
if err != nil {
log.Error("Insert subscription error: ", err)
retcodes = append(retcodes, message.QosFailure)
}
retcodes = append(retcodes, qos[i])
} else {
//if exist ,check whether qos change
c.subs[topic].qos = qos[i]
retcodes = append(retcodes, qos[i])
sub := &subscription{
topic: t,
qos: qoss[i],
client: c,
queue: queue,
}
switch c.typ {
case CLIENT:
if _, exist := c.subs[topic]; !exist {
c.subs[topic] = sub
} else {
//if exist ,check whether qos change
c.subs[topic].qos = qoss[i]
retcodes = append(retcodes, qoss[i])
continue
}
case ROUTER:
if subinfo, exist := c.rsubs[topic]; !exist {
sinfo := &subInfo{sub: sub, num: 1}
c.rsubs[topic] = sinfo
} else {
subinfo.num = subinfo.num + 1
retcodes = append(retcodes, qoss[i])
continue
}
}
err := b.sl.Insert(sub)
if err != nil {
log.Error("Insert subscription error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
retcodes = append(retcodes, QosFailure)
} else {
retcodes = append(retcodes, qoss[i])
}
}
suback.ReturnCodes = retcodes
if err := suback.AddReturnCodes(retcodes); err != nil {
log.Error("add return suback code error, ", err)
// if typ == CLIENT {
c.Close()
// }
return
}
err1 := c.writeMessage(suback)
if err1 != nil {
log.Error("send suback error, ", err1)
err := c.WriterPacket(suback)
if err != nil {
log.Error("send suback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
return
}
//broadcast subscribe message
if c.typ == CLIENT {
go b.BroadcastSubOrUnsubMessage(buf)
go b.BroadcastSubOrUnsubMessage(packet)
}
//process retain message
for _, t := range topics {
bufs := b.rl.Match(t)
for _, buf := range bufs {
log.Info("process retain message: ", string(buf))
if buf != nil && string(buf) != "" {
c.writeBuffer(buf)
packets := b.rl.Match(t)
for _, packet := range packets {
log.Info("process retain message: ", zap.Any("packet", packet), zap.String("ClientID", c.info.clientID))
if packet != nil {
c.WriterPacket(packet)
}
}
}
}
func (c *client) ProcessUnSubscribe(buf []byte) {
func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) {
if c.status == Disconnected {
return
}
b := c.broker
if b == nil {
return
}
unsub, err := DecodeUnsubscribeMessage(buf)
if err != nil {
log.Error("Decode UnSubscribe Message error: ", err)
c.Close()
return
}
topics := unsub.Topics()
typ := c.typ
topics := packet.Topics
for _, t := range topics {
var sub *subscription
ok := false
if sub, ok = c.subs[string(t)]; ok {
go c.unsubscribe(sub)
switch typ {
case CLIENT:
sub, ok := c.subs[t]
if ok {
c.unsubscribe(sub)
}
case ROUTER:
subinfo, ok := c.rsubs[t]
if ok {
subinfo.num = subinfo.num - 1
if subinfo.num < 1 {
delete(c.rsubs, t)
c.unsubscribe(subinfo.sub)
} else {
c.rsubs[t] = subinfo
}
}
}
}
resp := message.NewUnsubackMessage()
resp.SetPacketId(unsub.PacketId())
unsuback := packets.NewControlPacket(packets.Unsuback).(*packets.UnsubackPacket)
unsuback.MessageID = packet.MessageID
err1 := c.writeMessage(resp)
if err1 != nil {
log.Error("send ubsuback error, ", err1)
err := c.WriterPacket(unsuback)
if err != nil {
log.Error("send unsuback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
return
}
// //process ubsubscribe message
if c.typ == CLIENT {
b.BroadcastSubOrUnsubMessage(buf)
b.BroadcastSubOrUnsubMessage(packet)
}
}
func (c *client) unsubscribe(sub *subscription) {
c.mu.Lock()
delete(c.subs, string(sub.topic))
c.mu.Unlock()
if c.typ == CLIENT {
delete(c.subs, sub.topic)
if c.broker != nil {
c.broker.sl.Remove(sub)
}
b := c.broker
if b != nil && sub != nil {
b.sl.Remove(sub)
}
}
func (c *client) ProcessPing(buf []byte) {
_, err := DecodePingreqMessage(buf)
if err != nil {
log.Error("Decode PingRequest Message error: ", err)
c.Close()
func (c *client) ProcessPing() {
if c.status == Disconnected {
return
}
pingRspMsg := message.NewPingrespMessage()
err = c.writeMessage(pingRspMsg)
resp := packets.NewControlPacket(packets.Pingresp).(*packets.PingrespPacket)
err := c.WriterPacket(resp)
if err != nil {
log.Error("send PingResponse error, ", err)
log.Error("send PingResponse error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
return
}
}
func (c *client) Close() {
c.smu.Lock()
if c.status == Disconnected {
c.smu.Unlock()
return
}
//wait for message complete
time.Sleep(1 * time.Second)
c.status = Disconnected
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
c.smu.Unlock()
close(c.closed)
b := c.broker
subs := c.subs
if b != nil {
@@ -402,37 +507,34 @@ func (c *client) Close() {
for _, sub := range subs {
err := b.sl.Remove(sub)
if err != nil {
log.Error("closed client but remove sublist error, ", err)
log.Error("closed client but remove sublist error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
}
}
if c.info.willMsg != nil {
b.ProcessPublishMessage(c.info.willMsg)
if c.typ == CLIENT {
b.BroadcastUnSubscribe(subs)
}
if c.info.willMsg != nil {
b.PublishMessage(c.info.willMsg)
}
if c.typ == CLUSTER {
b.ConnectToDiscovery()
}
//do reconnect
if c.typ == REMOTE {
go b.connectRouter(c.route.remoteID, c.route.remoteUrl)
}
}
if c.conn != nil {
c.conn.Close()
c.conn = nil
}
}
func WriteBuffer(conn net.Conn, buf []byte) error {
if conn == nil {
return errors.New("conn is nul")
func (c *client) WriterPacket(packet packets.ControlPacket) error {
if packet == nil {
return nil
}
_, err := conn.Write(buf)
return err
}
func (c *client) writeBuffer(buf []byte) error {
c.mu.Lock()
err := WriteBuffer(c.conn, buf)
err := packet.Write(c.conn)
c.mu.Unlock()
return err
}
func (c *client) writeMessage(msg message.Message) error {
buf, err := EncodeMessage(msg)
if err != nil {
return err
}
return c.writeBuffer(buf)
}

View File

@@ -1,73 +0,0 @@
package broker
import "sync"
type cMap interface {
Set(key string, val *client)
Get(key string) (*client, bool)
Items() map[string]*client
Exist(key string) bool
Update(key string, val *client) (*client, bool)
Count() int
Remove(key string)
}
type clientMap struct {
items map[string]*client
mu sync.RWMutex
}
func NewClientMap() cMap {
smap := &clientMap{
items: make(map[string]*client),
}
return smap
}
func (s *clientMap) Set(key string, val *client) {
s.mu.Lock()
s.items[key] = val
s.mu.Unlock()
}
func (s *clientMap) Get(key string) (*client, bool) {
s.mu.RLock()
val, ok := s.items[key]
s.mu.RUnlock()
return val, ok
}
func (s *clientMap) Exist(key string) bool {
s.mu.RLock()
_, ok := s.items[key]
s.mu.RUnlock()
return ok
}
func (s *clientMap) Update(key string, val *client) (*client, bool) {
s.mu.Lock()
old, ok := s.items[key]
s.items[key] = val
s.mu.Unlock()
return old, ok
}
func (s *clientMap) Count() int {
s.mu.RLock()
len := len(s.items)
s.mu.RUnlock()
return len
}
func (s *clientMap) Remove(key string) {
s.mu.Lock()
delete(s.items, key)
s.mu.Unlock()
}
func (s *clientMap) Items() map[string]*client {
s.mu.RLock()
items := s.items
s.mu.RUnlock()
return items
}

View File

@@ -1,7 +1,8 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
import (
"bytes"
"crypto/md5"
"crypto/rand"
"encoding/base64"
@@ -40,17 +41,17 @@ const (
PINGRESP
DISCONNECT
)
const (
QosAtMostOnce byte = iota
QosAtLeastOnce
QosExactlyOnce
QosFailure = 0x80
)
func SubscribeTopicCheckAndSpilt(subject []byte) ([]string, error) {
topic := string(subject)
if bytes.IndexByte(subject, '#') != -1 {
if bytes.IndexByte(subject, '#') != len(subject)-1 {
return nil, errors.New("Topic format error with index of #")
}
func SubscribeTopicCheckAndSpilt(topic string) ([]string, error) {
if strings.Index(topic, "#") != -1 && strings.Index(topic, "#") != len(topic)-1 {
return nil, errors.New("Topic format error with index of #")
}
re := strings.Split(topic, "/")
for i, v := range re {
if i != 0 && i != (len(re)-1) {
@@ -70,11 +71,10 @@ func SubscribeTopicCheckAndSpilt(subject []byte) ([]string, error) {
}
func PublishTopicCheckAndSpilt(subject []byte) ([]string, error) {
if bytes.IndexByte(subject, '#') != -1 || bytes.IndexByte(subject, '+') != -1 {
func PublishTopicCheckAndSpilt(topic string) ([]string, error) {
if strings.Index(topic, "#") != -1 || strings.Index(topic, "+") != -1 {
return nil, errors.New("Publish Topic format error with + and #")
}
topic := string(subject)
re := strings.Split(topic, "/")
for i, v := range re {
if v == "" {

View File

@@ -1,23 +1,25 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
import (
"crypto/tls"
"crypto/x509"
"encoding/json"
"errors"
"flag"
"fmt"
"io/ioutil"
log "github.com/cihub/seelog"
)
const (
CONFIGFILE = "broker.config"
"go.uber.org/zap"
)
type Config struct {
Worker int `json:"workerNum"`
Host string `json:"host"`
Port string `json:"port"`
Cluster RouteInfo `json:"cluster"`
Router string `json:"router"`
TlsHost string `json:"tlsHost"`
TlsPort string `json:"tlsPort"`
WsPath string `json:"wsPath"`
@@ -29,9 +31,8 @@ type Config struct {
}
type RouteInfo struct {
Host string `json:"host"`
Port string `json:"port"`
Routes []string `json:"routes"`
Host string `json:"host"`
Port string `json:"port"`
}
type TLSInfo struct {
@@ -41,11 +42,60 @@ type TLSInfo struct {
KeyFile string `json:"keyFile"`
}
func LoadConfig() (*Config, error) {
var DefaultConfig *Config = &Config{
Worker: 4096,
Host: "0.0.0.0",
Port: "1883",
Acl: false,
}
content, err := ioutil.ReadFile(CONFIGFILE)
func ConfigureConfig() (*Config, error) {
config := &Config{}
var (
configFile string
)
flag.IntVar(&config.Worker, "w", 1024, "worker num to process message, perfer (client num)/10.")
flag.IntVar(&config.Worker, "worker", 1024, "worker num to process message, perfer (client num)/10.")
flag.StringVar(&config.Port, "port", "1883", "Port to listen on.")
flag.StringVar(&config.Port, "p", "1883", "Port to listen on.")
flag.StringVar(&config.Host, "host", "0.0.0.0", "Network host to listen on.")
flag.StringVar(&config.Host, "h", "0.0.0.0", "Network host to listen on.")
flag.StringVar(&config.Cluster.Host, "cluster", "", "Cluster ip from which members can connect.")
flag.StringVar(&config.Cluster.Host, "cluster_listen", "", "Cluster ip from which members can connect.")
flag.StringVar(&config.Cluster.Port, "cp", "", "Cluster port from which members can connect.")
flag.StringVar(&config.Cluster.Port, "cluster_port", "", "Cluster port from which members can connect.")
flag.StringVar(&config.Router, "r", "", "Router who maintenance cluster info")
flag.StringVar(&config.Router, "router", "", "Router who maintenance cluster info")
flag.StringVar(&config.WsPort, "wsport", "", "port for ws to listen on")
flag.StringVar(&config.WsPort, "ws_port", "", "port for ws to listen on")
flag.StringVar(&config.WsPath, "wspath", "", "path for ws to listen on")
flag.StringVar(&config.WsPath, "ws_path", "", "path for ws to listen on")
flag.StringVar(&configFile, "config", "", "config file for hmq")
flag.StringVar(&configFile, "c", "", "config file for hmq")
flag.Parse()
if configFile != "" {
tmpConfig, e := LoadConfig(configFile)
if e != nil {
return nil, e
} else {
config = tmpConfig
}
}
if err := config.check(); err != nil {
return nil, err
}
return config, nil
}
func LoadConfig(filename string) (*Config, error) {
content, err := ioutil.ReadFile(filename)
if err != nil {
log.Error("Read config file error: ", err)
log.Error("Read config file error: ", zap.Error(err))
return nil, err
}
// log.Info(string(content))
@@ -53,10 +103,21 @@ func LoadConfig() (*Config, error) {
var config Config
err = json.Unmarshal(content, &config)
if err != nil {
log.Error("Unmarshal config file error: ", err)
log.Error("Unmarshal config file error: ", zap.Error(err))
return nil, err
}
return &config, nil
}
func (config *Config) check() error {
if config.Worker == 0 {
config.Worker = 1024
}
WorkNum = config.Worker
if config.Port != "" {
if config.Host == "" {
config.Host = "0.0.0.0"
@@ -68,29 +129,33 @@ func LoadConfig() (*Config, error) {
config.Cluster.Host = "0.0.0.0"
}
}
if config.Router != "" {
if config.Cluster.Port == "" {
return errors.New("cluster port is null")
}
}
if config.TlsPort != "" {
if config.TlsInfo.CertFile == "" || config.TlsInfo.KeyFile == "" {
log.Error("tls config error, no cert or key file.")
return nil, err
return errors.New("tls config error, no cert or key file.")
}
if config.TlsHost == "" {
config.TlsHost = "0.0.0.0"
}
}
return &config, nil
return nil
}
func NewTLSConfig(tlsInfo TLSInfo) (*tls.Config, error) {
cert, err := tls.LoadX509KeyPair(tlsInfo.CertFile, tlsInfo.KeyFile)
if err != nil {
return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", err)
return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", zap.Error(err))
}
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
if err != nil {
return nil, fmt.Errorf("error parsing certificate: %v", err)
return nil, fmt.Errorf("error parsing certificate: %v", zap.Error(err))
}
// Create TLSConfig

View File

@@ -1,14 +1,14 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
const (
WorkNum = 2048
)
var WorkNum int
type Dispatcher struct {
WorkerPool chan chan *Message
}
func init() {
func StartDispatcher() {
InitMessagePool()
dispatcher := NewDispatcher()
dispatcher.Run()
@@ -29,7 +29,7 @@ func NewDispatcher() *Dispatcher {
}
func (d *Dispatcher) dispatch() {
for i := 0; i < MessagePoolNum; i++ {
for i := 0; i < (MessagePoolNum + 3); i++ {
go func(idx int) {
for {
select {

View File

@@ -1,35 +1,45 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
import (
"fmt"
"hmq/lib/message"
"time"
"github.com/eclipse/paho.mqtt.golang/packets"
"go.uber.org/zap"
simplejson "github.com/bitly/go-simplejson"
log "github.com/cihub/seelog"
)
func (c *client) SendInfo() {
if c.status == Disconnected {
return
}
url := c.info.localIP + ":" + c.broker.config.Cluster.Port
infoMsg := NewInfo(c.broker.id, url, false)
err := c.writeMessage(infoMsg)
err := c.WriterPacket(infoMsg)
if err != nil {
log.Error("send info message error, ", err)
log.Error("send info message error, ", zap.Error(err))
return
}
// log.Info("send info success")
}
func (c *client) StartPing() {
timeTicker := time.NewTicker(time.Second * 30)
ping := message.NewPingreqMessage()
timeTicker := time.NewTicker(time.Second * 50)
ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket)
for {
select {
case <-timeTicker.C:
err := c.writeMessage(ping)
err := c.WriterPacket(ping)
if err != nil {
log.Error("ping error: ", err)
log.Error("ping error: ", zap.Error(err))
c.Close()
}
case _, ok := <-c.closed:
if !ok {
return
}
}
}
@@ -37,77 +47,70 @@ func (c *client) StartPing() {
func (c *client) SendConnect() {
clientID := c.info.clientID
connMsg := message.NewConnectMessage()
connMsg.SetClientId(clientID)
connMsg.SetVersion(0x04)
err := c.writeMessage(connMsg)
if err != nil {
log.Error("send connect message error, ", err)
if c.status != Connected {
return
}
// log.Info("send connet success")
m := packets.NewControlPacket(packets.Connect).(*packets.ConnectPacket)
m.CleanSession = true
m.ClientIdentifier = c.info.clientID
m.Keepalive = uint16(60)
err := c.WriterPacket(m)
if err != nil {
log.Error("send connect message error, ", zap.Error(err))
return
}
log.Info("send connect success")
}
func NewInfo(sid, url string, isforword bool) *message.PublishMessage {
infoMsg := message.NewPublishMessage()
infoMsg.SetTopic([]byte(BrokerInfoTopic))
info := fmt.Sprintf(`{"remoteID":"%s","url":"%s","isForward":%t}`, sid, url, isforword)
func NewInfo(sid, url string, isforword bool) *packets.PublishPacket {
pub := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
pub.Qos = 0
pub.TopicName = BrokerInfoTopic
pub.Retain = false
info := fmt.Sprintf(`{"brokerID":"%s","brokerUrl":"%s"}`, sid, url)
// log.Info("new info", string(info))
infoMsg.SetPayload([]byte(info))
infoMsg.SetQoS(0)
infoMsg.SetRetain(false)
return infoMsg
pub.Payload = []byte(info)
return pub
}
func (c *client) ProcessInfo(msg *message.PublishMessage) {
func (c *client) ProcessInfo(packet *packets.PublishPacket) {
nc := c.conn
b := c.broker
if nc == nil {
return
}
log.Info("recv remoteInfo: ", string(msg.Payload()))
log.Info("recv remoteInfo: ", zap.String("payload", string(packet.Payload)))
js, e := simplejson.NewJson(msg.Payload())
if e != nil {
log.Warn("parse info message err", e)
js, err := simplejson.NewJson(packet.Payload)
if err != nil {
log.Warn("parse info message err", zap.Error(err))
return
}
rid := js.Get("remoteID").MustString()
rurl := js.Get("url").MustString()
isForward := js.Get("isForward").MustBool()
if rid == "" {
log.Error("receive info message error with remoteID is null")
routes, err := js.Get("data").Map()
if routes == nil {
log.Error("receive info message error, ", zap.Error(err))
return
}
if rid == b.id {
if !isForward {
c.Close() //close connet self
b.nodes = routes
b.mu.Lock()
for rid, rurl := range routes {
if rid == b.id {
continue
}
return
}
exist := b.CheckRemoteExist(rid, rurl)
if !exist {
go b.connectRouter(rurl, rid)
}
// log.Info("isforword: ", isForward)
if !isForward {
route := &route{
remoteUrl: rurl,
remoteID: rid,
url, ok := rurl.(string)
if ok {
exist := b.CheckRemoteExist(rid, url)
if !exist {
b.connectRouter(rid, url)
}
}
c.route = route
go b.SendLocalSubsToRouter(c)
// log.Info("BroadcastInfoMessage starting... ")
infoMsg := NewInfo(rid, rurl, true)
b.BroadcastInfoMessage(rid, infoMsg)
}
return
b.mu.Unlock()
}

View File

@@ -1,6 +1,12 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
import "sync"
import (
"sync"
"github.com/eclipse/paho.mqtt.golang/packets"
)
const (
MaxUser = 1024 * 1024
@@ -11,7 +17,7 @@ const (
type Message struct {
client *client
msg []byte
packet packets.ControlPacket
}
var (
@@ -26,8 +32,8 @@ type MessagePool struct {
}
func InitMessagePool() {
MSGPool = make([]MessagePool, (MessagePoolNum + 2))
for i := 0; i < (MessagePoolNum + 2); i++ {
MSGPool = make([]MessagePool, (MessagePoolNum + 3))
for i := 0; i < (MessagePoolNum + 3); i++ {
MSGPool[i].Init(MessagePoolUser, MessagePoolMessageNum)
}
}

View File

@@ -1,236 +0,0 @@
package broker
import (
"errors"
"hmq/lib/message"
"io"
"net"
log "github.com/cihub/seelog"
)
func checkError(desc string, err error) {
if err != nil {
log.Error(desc, " : ", err)
}
}
func ReadPacket(conn net.Conn) ([]byte, error) {
if conn == nil {
return nil, errors.New("conn is null")
}
// conn.SetReadDeadline(t)
var buf []byte
// read fix header
b := make([]byte, 1)
_, err := io.ReadFull(conn, b)
if err != nil {
return nil, err
}
buf = append(buf, b...)
// read rem msg length
rembuf, remlen := decodeLength(conn)
buf = append(buf, rembuf...)
// read rem msg
packetBytes := make([]byte, remlen)
_, err = io.ReadFull(conn, packetBytes)
if err != nil {
return nil, err
}
buf = append(buf, packetBytes...)
// log.Info("len buf: ", len(buf))
return buf, nil
}
func decodeLength(r io.Reader) ([]byte, int) {
var rLength uint32
var multiplier uint32
var buf []byte
b := make([]byte, 1)
for {
io.ReadFull(r, b)
digit := b[0]
buf = append(buf, b[0])
rLength |= uint32(digit&127) << multiplier
if (digit & 128) == 0 {
break
}
multiplier += 7
}
return buf, int(rLength)
}
func DecodeMessage(buf []byte) (message.Message, error) {
msgType := uint8(buf[0] & 0xF0 >> 4)
switch msgType {
case CONNECT:
return DecodeConnectMessage(buf)
case CONNACK:
return DecodeConnackMessage(buf)
case PUBLISH:
return DecodePublishMessage(buf)
case PUBACK:
return DecodePubackMessage(buf)
case PUBCOMP:
return DecodePubcompMessage(buf)
case PUBREC:
return DecodePubrecMessage(buf)
case PUBREL:
return DecodePubrelMessage(buf)
case SUBSCRIBE:
return DecodeSubscribeMessage(buf)
case SUBACK:
return DecodeSubackMessage(buf)
case UNSUBSCRIBE:
return DecodeUnsubscribeMessage(buf)
case UNSUBACK:
return DecodeUnsubackMessage(buf)
case PINGREQ:
return DecodePingreqMessage(buf)
case PINGRESP:
return DecodePingrespMessage(buf)
case DISCONNECT:
return DecodeDisconnectMessage(buf)
default:
return nil, errors.New("error message type")
}
}
func DecodeConnectMessage(buf []byte) (*message.ConnectMessage, error) {
connMsg := message.NewConnectMessage()
_, err := connMsg.Decode(buf)
if err != nil {
if !message.ValidConnackError(err) {
return nil, errors.New("Connect message format error, " + err.Error())
}
return nil, errors.New("Deode connect message error, " + err.Error())
}
return connMsg, nil
}
func DecodeConnackMessage(buf []byte) (*message.ConnackMessage, error) {
msg := message.NewConnackMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Connack message error, " + err.Error())
}
return msg, nil
}
func DecodePublishMessage(buf []byte) (*message.PublishMessage, error) {
msg := message.NewPublishMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Publish message error, " + err.Error())
}
return msg, nil
}
func DecodePubackMessage(buf []byte) (*message.PubackMessage, error) {
msg := message.NewPubackMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Puback message error, " + err.Error())
}
return msg, nil
}
func DecodePubrecMessage(buf []byte) (*message.PubrecMessage, error) {
msg := message.NewPubrecMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Pubrec message error, " + err.Error())
}
return msg, nil
}
func DecodePubrelMessage(buf []byte) (*message.PubrelMessage, error) {
msg := message.NewPubrelMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Pubrel message error, " + err.Error())
}
return msg, nil
}
func DecodePubcompMessage(buf []byte) (*message.PubcompMessage, error) {
msg := message.NewPubcompMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Pubcomp message error, " + err.Error())
}
return msg, nil
}
func DecodeSubscribeMessage(buf []byte) (*message.SubscribeMessage, error) {
msg := message.NewSubscribeMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Subscribe message error, " + err.Error())
}
return msg, nil
}
func DecodeSubackMessage(buf []byte) (*message.SubackMessage, error) {
msg := message.NewSubackMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Suback message error, " + err.Error())
}
return msg, nil
}
func DecodeUnsubscribeMessage(buf []byte) (*message.UnsubscribeMessage, error) {
msg := message.NewUnsubscribeMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Unsubscribe message error, " + err.Error())
}
return msg, nil
}
func DecodeUnsubackMessage(buf []byte) (*message.UnsubackMessage, error) {
msg := message.NewUnsubackMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Unsuback message error, " + err.Error())
}
return msg, nil
}
func DecodePingreqMessage(buf []byte) (*message.PingreqMessage, error) {
msg := message.NewPingreqMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Pingreq message error, " + err.Error())
}
return msg, nil
}
func DecodePingrespMessage(buf []byte) (*message.PingrespMessage, error) {
msg := message.NewPingrespMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Pingresp message error, " + err.Error())
}
return msg, nil
}
func DecodeDisconnectMessage(buf []byte) (*message.DisconnectMessage, error) {
msg := message.NewDisconnectMessage()
_, err := msg.Decode(buf)
if err != nil {
return nil, errors.New("Decode Disconnect message error, " + err.Error())
}
return msg, nil
}
func EncodeMessage(msg message.Message) ([]byte, error) {
buf := make([]byte, msg.Len())
_, err := msg.Encode(buf)
if err != nil {
return nil, err
}
return buf, nil
}

View File

@@ -2,6 +2,8 @@ package broker
import (
"sync"
"github.com/eclipse/paho.mqtt.golang/packets"
)
type RetainList struct {
@@ -13,14 +15,14 @@ type rlevel struct {
}
type rnode struct {
next *rlevel
msg []byte
msg *packets.PublishPacket
}
type RetainResult struct {
msg [][]byte
msg []*packets.PublishPacket
}
func newRNode() *rnode {
return &rnode{msg: make([]byte, 0, 4)}
return &rnode{}
}
func newRLevel() *rlevel {
@@ -31,7 +33,7 @@ func NewRetainList() *RetainList {
return &RetainList{root: newRLevel()}
}
func (r *RetainList) Insert(topic, buf []byte) error {
func (r *RetainList) Insert(topic string, buf *packets.PublishPacket) error {
tokens, err := PublishTopicCheckAndSpilt(topic)
if err != nil {
@@ -58,7 +60,7 @@ func (r *RetainList) Insert(topic, buf []byte) error {
return nil
}
func (r *RetainList) Match(topic []byte) [][]byte {
func (r *RetainList) Match(topic string) []*packets.PublishPacket {
tokens, err := SubscribeTopicCheckAndSpilt(topic)
if err != nil {
@@ -110,7 +112,7 @@ func matchRLevel(l *rlevel, toks []string, results *RetainResult) {
func (r *rnode) GetAll(results *RetainResult) {
// log.Info("node 's message: ", string(r.msg))
if r.msg != nil && string(r.msg) != "" {
if r.msg != nil {
results.msg = append(results.msg, r.msg)
}
l := r.next

View File

@@ -1,10 +1,12 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
import (
"errors"
"sync"
log "github.com/cihub/seelog"
"go.uber.org/zap"
)
// A result structure better optimized for queue subs.
@@ -124,8 +126,8 @@ func (s *Sublist) removeFromCache(topic string, sub *subscription) {
}
func matchLiteral(literal, topic string) bool {
tok, _ := SubscribeTopicCheckAndSpilt([]byte(topic))
li, _ := PublishTopicCheckAndSpilt([]byte(literal))
tok, _ := SubscribeTopicCheckAndSpilt(topic)
li, _ := PublishTopicCheckAndSpilt(literal)
for i := 0; i < len(tok); i++ {
b := tok[i]
@@ -207,9 +209,9 @@ func (s *Sublist) Match(topic string) *SublistResult {
return rc
}
tokens, err := PublishTopicCheckAndSpilt([]byte(topic))
tokens, err := PublishTopicCheckAndSpilt(topic)
if err != nil {
log.Error("\tserver/sublist.go: ", err)
log.Error("\tserver/sublist.go: ", zap.Error(err))
return nil
}
@@ -241,7 +243,6 @@ func (s *Sublist) Match(topic string) *SublistResult {
}
s.Unlock()
// log.Info("SublistResult: ", result)
return result
}
@@ -294,7 +295,6 @@ func removeSubFromList(sub *subscription, sl []*subscription) ([]*subscription,
sl[i] = sl[last]
sl[last] = nil
sl = sl[:last]
// log.Info("removeSubFromList success")
return shrinkAsNeeded(sl), true
}
}

View File

@@ -1,3 +1,5 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package broker
type Worker struct {

View File

@@ -1,11 +1,12 @@
{
"workerNum": 4096,
"port": "1883",
"host": "0.0.0.0",
"cluster": {
"host": "0.0.0.0",
"port": "1993",
"routes": []
"port": "1993"
},
"router": "127.0.0.1:9888",
"tlsPort": "8883",
"tlsHost": "0.0.0.0",
"wsPort": "1888",
@@ -17,6 +18,6 @@
"certFile": "ssl/server/cert.pem",
"keyFile": "ssl/server/key.pem"
},
"acl": true,
"acl": false,
"aclConf": "conf/acl.conf"
}
}

BIN
hmq

Binary file not shown.

View File

@@ -1,3 +1,5 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package acl
import (

View File

@@ -1,3 +1,4 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>*/
package acl
import "strings"

View File

@@ -1,3 +1,5 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package acl
import (

View File

@@ -1,16 +0,0 @@
# This is the official list of SurgeMQ authors for copyright purposes.
# If you are submitting a patch, please add your name or the name of the
# organization which holds the copyright to this list in alphabetical order.
# Names should be added to this file as
# Name <email address>
# The email address is not required for organizations.
# Please keep the list sorted.
# Individual Persons
Jian Zhen <zhenjl@gmail.com>
# Organizations

View File

@@ -1,36 +0,0 @@
# Contributing Guidelines
## Reporting Issues
Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/surgemq/message/issues?state=open) or was [recently closed](https://github.com/surgemq/message/issues?direction=desc&page=1&sort=updated&state=closed).
Please provide the following minimum information:
* Your SurgeMQ version (or git SHA)
* Your Go version (run `go version` in your console)
* A detailed issue description
* Error Log if present
* If possible, a short example
## Contributing Code
By contributing to this project, you share your code under the Apache License, Version 2.0, as specified in the LICENSE file.
Don't forget to add yourself to the AUTHORS file.
### Pull Requests Checklist
Please check the following points before submitting your pull request:
- [x] Code compiles correctly
- [x] Created tests, if possible
- [x] All tests pass
- [x] Extended the README / documentation, if necessary
- [x] Added yourself to the AUTHORS file
### Code Review
Everyone is invited to review and comment on pull requests.
If it looks fine to you, comment with "LGTM" (Looks good to me).
If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes.
Before merging the Pull Request, at least one [team member](https://github.com/orgs/surgemq/people) must have commented with "LGTM".

View File

@@ -1,136 +0,0 @@
Package message is an encoder/decoder library for MQTT 3.1 and 3.1.1 messages. You can
find the MQTT specs at the following locations:
> 3.1.1 - http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/
> 3.1 - http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html
From the spec:
> MQTT is a Client Server publish/subscribe messaging transport protocol. It is
> light weight, open, simple, and designed so as to be easy to implement. These
> characteristics make it ideal for use in many situations, including constrained
> environments such as for communication in Machine to Machine (M2M) and Internet
> of Things (IoT) contexts where a small code footprint is required and/or network
> bandwidth is at a premium.
>
> The MQTT protocol works by exchanging a series of MQTT messages in a defined way.
> The protocol runs over TCP/IP, or over other network protocols that provide
> ordered, lossless, bi-directional connections.
There are two main items to take note in this package. The first is
```
type MessageType byte
```
MessageType is the type representing the MQTT packet types. In the MQTT spec, MQTT
control packet type is represented as a 4-bit unsigned value. MessageType receives
several methods that returns string representations of the names and descriptions.
Also, one of the methods is New(). It returns a new Message object based on the mtype
parameter. For example:
```
m, err := CONNECT.New()
msg := m.(*ConnectMessage)
```
This would return a PublishMessage struct, but mapped to the Message interface. You can
then type assert it back to a *PublishMessage. Another way to create a new
PublishMessage is to call
```
msg := NewConnectMessage()
```
Every message type has a New function that returns a new message. The list of available
message types are defined as constants below.
As you may have noticed, the second important item is the Message interface. It defines
several methods that are common to all messages, including Name(), Desc(), and Type().
Most importantly, it also defines the Encode() and Decode() methods.
```
Encode() (io.Reader, int, error)
Decode(io.Reader) (int, error)
```
Encode returns an io.Reader in which the encoded bytes can be read. The second return
value is the number of bytes encoded, so the caller knows how many bytes there will be.
If Encode returns an error, then the first two return values should be considered invalid.
Any changes to the message after Encode() is called will invalidate the io.Reader.
Decode reads from the io.Reader parameter until a full message is decoded, or when io.Reader
returns EOF or error. The first return value is the number of bytes read from io.Reader.
The second is error if Decode encounters any problems.
With these in mind, we can now do:
```
// Create a new CONNECT message
msg := NewConnectMessage()
// Set the appropriate parameters
msg.SetWillQos(1)
msg.SetVersion(4)
msg.SetCleanSession(true)
msg.SetClientId([]byte("surgemq"))
msg.SetKeepAlive(10)
msg.SetWillTopic([]byte("will"))
msg.SetWillMessage([]byte("send me home"))
msg.SetUsername([]byte("surgemq"))
msg.SetPassword([]byte("verysecret"))
// Encode the message and get the io.Reader
r, n, err := msg.Encode()
if err == nil {
return err
}
// Write n bytes into the connection
m, err := io.CopyN(conn, r, int64(n))
if err != nil {
return err
}
fmt.Printf("Sent %d bytes of %s message", m, msg.Name())
```
To receive a CONNECT message from a connection, we can do:
```
// Create a new CONNECT message
msg := NewConnectMessage()
// Decode the message by reading from conn
n, err := msg.Decode(conn)
```
If you don't know what type of message is coming down the pipe, you can do something like this:
```
// Create a buffered IO reader for the connection
br := bufio.NewReader(conn)
// Peek at the first byte, which contains the message type
b, err := br.Peek(1)
if err != nil {
return err
}
// Extract the type from the first byte
t := MessageType(b[0] >> 4)
// Create a new message
msg, err := t.New()
if err != nil {
return err
}
// Decode it from the bufio.Reader
n, err := msg.Decode(br)
if err != nil {
return err
}
```

View File

@@ -1,168 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import "fmt"
// The CONNACK Packet is the packet sent by the Server in response to a CONNECT Packet
// received from a Client. The first packet sent from the Server to the Client MUST
// be a CONNACK Packet [MQTT-3.2.0-1].
//
// If the Client does not receive a CONNACK Packet from the Server within a reasonable
// amount of time, the Client SHOULD close the Network Connection. A "reasonable" amount
// of time depends on the type of application and the communications infrastructure.
type ConnackMessage struct {
header
sessionPresent bool
returnCode ConnackCode
}
var _ Message = (*ConnackMessage)(nil)
// NewConnackMessage creates a new CONNACK message
func NewConnackMessage() *ConnackMessage {
msg := &ConnackMessage{}
msg.SetType(CONNACK)
return msg
}
// String returns a string representation of the CONNACK message
func (this ConnackMessage) String() string {
return fmt.Sprintf("%s, Session Present=%t, Return code=%q\n", this.header, this.sessionPresent, this.returnCode)
}
// SessionPresent returns the session present flag value
func (this *ConnackMessage) SessionPresent() bool {
return this.sessionPresent
}
// SetSessionPresent sets the value of the session present flag
func (this *ConnackMessage) SetSessionPresent(v bool) {
if v {
this.sessionPresent = true
} else {
this.sessionPresent = false
}
this.dirty = true
}
// ReturnCode returns the return code received for the CONNECT message. The return
// type is an error
func (this *ConnackMessage) ReturnCode() ConnackCode {
return this.returnCode
}
func (this *ConnackMessage) SetReturnCode(ret ConnackCode) {
this.returnCode = ret
this.dirty = true
}
func (this *ConnackMessage) Len() int {
if !this.dirty {
return len(this.dbuf)
}
ml := this.msglen()
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0
}
return this.header.msglen() + ml
}
func (this *ConnackMessage) Decode(src []byte) (int, error) {
total := 0
n, err := this.header.decode(src)
total += n
if err != nil {
return total, err
}
b := src[total]
if b&254 != 0 {
return 0, fmt.Errorf("connack/Decode: Bits 7-1 in Connack Acknowledge Flags byte (1) are not 0")
}
this.sessionPresent = b&0x1 == 1
total++
b = src[total]
// Read return code
if b > 5 {
return 0, fmt.Errorf("connack/Decode: Invalid CONNACK return code (%d)", b)
}
this.returnCode = ConnackCode(b)
total++
this.dirty = false
return total, nil
}
func (this *ConnackMessage) Encode(dst []byte) (int, error) {
if !this.dirty {
if len(dst) < len(this.dbuf) {
return 0, fmt.Errorf("connack/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
}
return copy(dst, this.dbuf), nil
}
// CONNACK remaining length fixed at 2 bytes
hl := this.header.msglen()
ml := this.msglen()
if len(dst) < hl+ml {
return 0, fmt.Errorf("connack/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
}
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0, err
}
total := 0
n, err := this.header.encode(dst[total:])
total += n
if err != nil {
return 0, err
}
if this.sessionPresent {
dst[total] = 1
}
total++
if this.returnCode > 5 {
return total, fmt.Errorf("connack/Encode: Invalid CONNACK return code (%d)", this.returnCode)
}
dst[total] = this.returnCode.Value()
total++
return total, nil
}
func (this *ConnackMessage) msglen() int {
return 2
}

View File

@@ -1,160 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestConnackMessageFields(t *testing.T) {
msg := NewConnackMessage()
msg.SetSessionPresent(true)
require.True(t, msg.SessionPresent(), "Error setting session present flag.")
msg.SetSessionPresent(false)
require.False(t, msg.SessionPresent(), "Error setting session present flag.")
msg.SetReturnCode(ConnectionAccepted)
require.Equal(t, ConnectionAccepted, msg.ReturnCode(), "Error setting return code.")
}
func TestConnackMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(CONNACK << 4),
2,
0, // session not present
0, // connection accepted
}
msg := NewConnackMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.False(t, msg.SessionPresent(), "Error decoding session present flag.")
require.Equal(t, ConnectionAccepted, msg.ReturnCode(), "Error decoding return code.")
}
// testing wrong message length
func TestConnackMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(CONNACK << 4),
3,
0, // session not present
0, // connection accepted
}
msg := NewConnackMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err, "Error decoding message.")
}
// testing wrong message size
func TestConnackMessageDecode3(t *testing.T) {
msgBytes := []byte{
byte(CONNACK << 4),
2,
0, // session not present
}
msg := NewConnackMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err, "Error decoding message.")
}
// testing wrong reserve bits
func TestConnackMessageDecode4(t *testing.T) {
msgBytes := []byte{
byte(CONNACK << 4),
2,
64, // <- wrong size
0, // connection accepted
}
msg := NewConnackMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err, "Error decoding message.")
}
// testing invalid return code
func TestConnackMessageDecode5(t *testing.T) {
msgBytes := []byte{
byte(CONNACK << 4),
2,
0,
6, // <- wrong code
}
msg := NewConnackMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err, "Error decoding message.")
}
func TestConnackMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(CONNACK << 4),
2,
1, // session present
0, // connection accepted
}
msg := NewConnackMessage()
msg.SetReturnCode(ConnectionAccepted)
msg.SetSessionPresent(true)
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error encoding message.")
require.Equal(t, msgBytes, dst[:n], "Error encoding connack message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestConnackDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(CONNACK << 4),
2,
0, // session not present
0, // connection accepted
}
msg := NewConnackMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,89 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
// ConnackCode is the type representing the return code in the CONNACK message,
// returned after the initial CONNECT message
type ConnackCode byte
const (
// Connection accepted
ConnectionAccepted ConnackCode = iota
// The Server does not support the level of the MQTT protocol requested by the Client
ErrInvalidProtocolVersion
// The Client identifier is correct UTF-8 but not allowed by the server
ErrIdentifierRejected
// The Network Connection has been made but the MQTT service is unavailable
ErrServerUnavailable
// The data in the user name or password is malformed
ErrBadUsernameOrPassword
// The Client is not authorized to connect
ErrNotAuthorized
)
// Value returns the value of the ConnackCode, which is just the byte representation
func (this ConnackCode) Value() byte {
return byte(this)
}
// Desc returns the description of the ConnackCode
func (this ConnackCode) Desc() string {
switch this {
case 0:
return "Connection accepted"
case 1:
return "The Server does not support the level of the MQTT protocol requested by the Client"
case 2:
return "The Client identifier is correct UTF-8 but not allowed by the server"
case 3:
return "The Network Connection has been made but the MQTT service is unavailable"
case 4:
return "The data in the user name or password is malformed"
case 5:
return "The Client is not authorized to connect"
}
return ""
}
// Valid checks to see if the ConnackCode is valid. Currently valid codes are <= 5
func (this ConnackCode) Valid() bool {
return this <= 5
}
// Error returns the corresonding error string for the ConnackCode
func (this ConnackCode) Error() string {
switch this {
case 0:
return "Connection accepted"
case 1:
return "Connection Refused, unacceptable protocol version"
case 2:
return "Connection Refused, identifier rejected"
case 3:
return "Connection Refused, Server unavailable"
case 4:
return "Connection Refused, bad user name or password"
case 5:
return "Connection Refused, not authorized"
}
return "Unknown error"
}

View File

@@ -1,635 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"encoding/binary"
"fmt"
"regexp"
)
var clientIdRegexp *regexp.Regexp
func init() {
// Added space for Paho compliance test
// Added underscore (_) for MQTT C client test
clientIdRegexp = regexp.MustCompile("^[0-9a-zA-Z _]*$")
}
// After a Network Connection is established by a Client to a Server, the first Packet
// sent from the Client to the Server MUST be a CONNECT Packet [MQTT-3.1.0-1].
//
// A Client can only send the CONNECT Packet once over a Network Connection. The Server
// MUST process a second CONNECT Packet sent from a Client as a protocol violation and
// disconnect the Client [MQTT-3.1.0-2]. See section 4.8 for information about
// handling errors.
type ConnectMessage struct {
header
// 7: username flag
// 6: password flag
// 5: will retain
// 4-3: will QoS
// 2: will flag
// 1: clean session
// 0: reserved
connectFlags byte
version byte
keepAlive uint16
protoName,
clientId,
willTopic,
willMessage,
username,
password []byte
}
var _ Message = (*ConnectMessage)(nil)
// NewConnectMessage creates a new CONNECT message.
func NewConnectMessage() *ConnectMessage {
msg := &ConnectMessage{}
msg.SetType(CONNECT)
return msg
}
// String returns a string representation of the CONNECT message
func (this ConnectMessage) String() string {
return fmt.Sprintf("%s, Connect Flags=%08b, Version=%d, KeepAlive=%d, Client ID=%q, Will Topic=%q, Will Message=%q, Username=%q, Password=%q",
this.header,
this.connectFlags,
this.Version(),
this.KeepAlive(),
this.ClientId(),
this.WillTopic(),
this.WillMessage(),
this.Username(),
this.Password(),
)
}
// Version returns the the 8 bit unsigned value that represents the revision level
// of the protocol used by the Client. The value of the Protocol Level field for
// the version 3.1.1 of the protocol is 4 (0x04).
func (this *ConnectMessage) Version() byte {
return this.version
}
// SetVersion sets the version value of the CONNECT message
func (this *ConnectMessage) SetVersion(v byte) error {
if _, ok := SupportedVersions[v]; !ok {
return fmt.Errorf("connect/SetVersion: Invalid version number %d", v)
}
this.version = v
this.dirty = true
return nil
}
// CleanSession returns the bit that specifies the handling of the Session state.
// The Client and Server can store Session state to enable reliable messaging to
// continue across a sequence of Network Connections. This bit is used to control
// the lifetime of the Session state.
func (this *ConnectMessage) CleanSession() bool {
return ((this.connectFlags >> 1) & 0x1) == 1
}
// SetCleanSession sets the bit that specifies the handling of the Session state.
func (this *ConnectMessage) SetCleanSession(v bool) {
if v {
this.connectFlags |= 0x2 // 00000010
} else {
this.connectFlags &= 253 // 11111101
}
this.dirty = true
}
// WillFlag returns the bit that specifies whether a Will Message should be stored
// on the server. If the Will Flag is set to 1 this indicates that, if the Connect
// request is accepted, a Will Message MUST be stored on the Server and associated
// with the Network Connection.
func (this *ConnectMessage) WillFlag() bool {
return ((this.connectFlags >> 2) & 0x1) == 1
}
// SetWillFlag sets the bit that specifies whether a Will Message should be stored
// on the server.
func (this *ConnectMessage) SetWillFlag(v bool) {
if v {
this.connectFlags |= 0x4 // 00000100
} else {
this.connectFlags &= 251 // 11111011
}
this.dirty = true
}
// WillQos returns the two bits that specify the QoS level to be used when publishing
// the Will Message.
func (this *ConnectMessage) WillQos() byte {
return (this.connectFlags >> 3) & 0x3
}
// SetWillQos sets the two bits that specify the QoS level to be used when publishing
// the Will Message.
func (this *ConnectMessage) SetWillQos(qos byte) error {
if qos != QosAtMostOnce && qos != QosAtLeastOnce && qos != QosExactlyOnce {
return fmt.Errorf("connect/SetWillQos: Invalid QoS level %d", qos)
}
this.connectFlags = (this.connectFlags & 231) | (qos << 3) // 231 = 11100111
this.dirty = true
return nil
}
// WillRetain returns the bit specifies if the Will Message is to be Retained when it
// is published.
func (this *ConnectMessage) WillRetain() bool {
return ((this.connectFlags >> 5) & 0x1) == 1
}
// SetWillRetain sets the bit specifies if the Will Message is to be Retained when it
// is published.
func (this *ConnectMessage) SetWillRetain(v bool) {
if v {
this.connectFlags |= 32 // 00100000
} else {
this.connectFlags &= 223 // 11011111
}
this.dirty = true
}
// UsernameFlag returns the bit that specifies whether a user name is present in the
// payload.
func (this *ConnectMessage) UsernameFlag() bool {
return ((this.connectFlags >> 7) & 0x1) == 1
}
// SetUsernameFlag sets the bit that specifies whether a user name is present in the
// payload.
func (this *ConnectMessage) SetUsernameFlag(v bool) {
if v {
this.connectFlags |= 128 // 10000000
} else {
this.connectFlags &= 127 // 01111111
}
this.dirty = true
}
// PasswordFlag returns the bit that specifies whether a password is present in the
// payload.
func (this *ConnectMessage) PasswordFlag() bool {
return ((this.connectFlags >> 6) & 0x1) == 1
}
// SetPasswordFlag sets the bit that specifies whether a password is present in the
// payload.
func (this *ConnectMessage) SetPasswordFlag(v bool) {
if v {
this.connectFlags |= 64 // 01000000
} else {
this.connectFlags &= 191 // 10111111
}
this.dirty = true
}
// KeepAlive returns a time interval measured in seconds. Expressed as a 16-bit word,
// it is the maximum time interval that is permitted to elapse between the point at
// which the Client finishes transmitting one Control Packet and the point it starts
// sending the next.
func (this *ConnectMessage) KeepAlive() uint16 {
return this.keepAlive
}
// SetKeepAlive sets the time interval in which the server should keep the connection
// alive.
func (this *ConnectMessage) SetKeepAlive(v uint16) {
this.keepAlive = v
this.dirty = true
}
// ClientId returns an ID that identifies the Client to the Server. Each Client
// connecting to the Server has a unique ClientId. The ClientId MUST be used by
// Clients and by Servers to identify state that they hold relating to this MQTT
// Session between the Client and the Server
func (this *ConnectMessage) ClientId() []byte {
return this.clientId
}
// SetClientId sets an ID that identifies the Client to the Server.
func (this *ConnectMessage) SetClientId(v []byte) error {
if len(v) > 0 && !this.validClientId(v) {
return ErrIdentifierRejected
}
this.clientId = v
this.dirty = true
return nil
}
// WillTopic returns the topic in which the Will Message should be published to.
// If the Will Flag is set to 1, the Will Topic must be in the payload.
func (this *ConnectMessage) WillTopic() []byte {
return this.willTopic
}
// SetWillTopic sets the topic in which the Will Message should be published to.
func (this *ConnectMessage) SetWillTopic(v []byte) {
this.willTopic = v
if len(v) > 0 {
this.SetWillFlag(true)
} else if len(this.willMessage) == 0 {
this.SetWillFlag(false)
}
this.dirty = true
}
// WillMessage returns the Will Message that is to be published to the Will Topic.
func (this *ConnectMessage) WillMessage() []byte {
return this.willMessage
}
// SetWillMessage sets the Will Message that is to be published to the Will Topic.
func (this *ConnectMessage) SetWillMessage(v []byte) {
this.willMessage = v
if len(v) > 0 {
this.SetWillFlag(true)
} else if len(this.willTopic) == 0 {
this.SetWillFlag(false)
}
this.dirty = true
}
// Username returns the username from the payload. If the User Name Flag is set to 1,
// this must be in the payload. It can be used by the Server for authentication and
// authorization.
func (this *ConnectMessage) Username() []byte {
return this.username
}
// SetUsername sets the username for authentication.
func (this *ConnectMessage) SetUsername(v []byte) {
this.username = v
if len(v) > 0 {
this.SetUsernameFlag(true)
} else {
this.SetUsernameFlag(false)
}
this.dirty = true
}
// Password returns the password from the payload. If the Password Flag is set to 1,
// this must be in the payload. It can be used by the Server for authentication and
// authorization.
func (this *ConnectMessage) Password() []byte {
return this.password
}
// SetPassword sets the username for authentication.
func (this *ConnectMessage) SetPassword(v []byte) {
this.password = v
if len(v) > 0 {
this.SetPasswordFlag(true)
} else {
this.SetPasswordFlag(false)
}
this.dirty = true
}
func (this *ConnectMessage) Len() int {
if !this.dirty {
return len(this.dbuf)
}
ml := this.msglen()
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0
}
return this.header.msglen() + ml
}
// For the CONNECT message, the error returned could be a ConnackReturnCode, so
// be sure to check that. Otherwise it's a generic error. If a generic error is
// returned, this Message should be considered invalid.
//
// Caller should call ValidConnackError(err) to see if the returned error is
// a Connack error. If so, caller should send the Client back the corresponding
// CONNACK message.
func (this *ConnectMessage) Decode(src []byte) (int, error) {
total := 0
n, err := this.header.decode(src[total:])
if err != nil {
return total + n, err
}
total += n
if n, err = this.decodeMessage(src[total:]); err != nil {
return total + n, err
}
total += n
this.dirty = false
return total, nil
}
func (this *ConnectMessage) Encode(dst []byte) (int, error) {
if !this.dirty {
if len(dst) < len(this.dbuf) {
return 0, fmt.Errorf("connect/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
}
return copy(dst, this.dbuf), nil
}
if this.Type() != CONNECT {
return 0, fmt.Errorf("connect/Encode: Invalid message type. Expecting %d, got %d", CONNECT, this.Type())
}
_, ok := SupportedVersions[this.version]
if !ok {
return 0, ErrInvalidProtocolVersion
}
hl := this.header.msglen()
ml := this.msglen()
if len(dst) < hl+ml {
return 0, fmt.Errorf("connect/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
}
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0, err
}
total := 0
n, err := this.header.encode(dst[total:])
total += n
if err != nil {
return total, err
}
n, err = this.encodeMessage(dst[total:])
total += n
if err != nil {
return total, err
}
return total, nil
}
func (this *ConnectMessage) encodeMessage(dst []byte) (int, error) {
total := 0
n, err := writeLPBytes(dst[total:], []byte(SupportedVersions[this.version]))
total += n
if err != nil {
return total, err
}
dst[total] = this.version
total += 1
dst[total] = this.connectFlags
total += 1
binary.BigEndian.PutUint16(dst[total:], this.keepAlive)
total += 2
n, err = writeLPBytes(dst[total:], this.clientId)
total += n
if err != nil {
return total, err
}
if this.WillFlag() {
n, err = writeLPBytes(dst[total:], this.willTopic)
total += n
if err != nil {
return total, err
}
n, err = writeLPBytes(dst[total:], this.willMessage)
total += n
if err != nil {
return total, err
}
}
// According to the 3.1 spec, it's possible that the usernameFlag is set,
// but the username string is missing.
if this.UsernameFlag() && len(this.username) > 0 {
n, err = writeLPBytes(dst[total:], this.username)
total += n
if err != nil {
return total, err
}
}
// According to the 3.1 spec, it's possible that the passwordFlag is set,
// but the password string is missing.
if this.PasswordFlag() && len(this.password) > 0 {
n, err = writeLPBytes(dst[total:], this.password)
total += n
if err != nil {
return total, err
}
}
return total, nil
}
func (this *ConnectMessage) decodeMessage(src []byte) (int, error) {
var err error
n, total := 0, 0
this.protoName, n, err = readLPBytes(src[total:])
total += n
if err != nil {
return total, err
}
this.version = src[total]
total++
if verstr, ok := SupportedVersions[this.version]; !ok {
return total, ErrInvalidProtocolVersion
} else if verstr != string(this.protoName) {
return total, ErrInvalidProtocolVersion
}
this.connectFlags = src[total]
total++
if this.connectFlags&0x1 != 0 {
return total, fmt.Errorf("connect/decodeMessage: Connect Flags reserved bit 0 is not 0")
}
if this.WillQos() > QosExactlyOnce {
return total, fmt.Errorf("connect/decodeMessage: Invalid QoS level (%d) for %s message", this.WillQos(), this.Name())
}
if !this.WillFlag() && (this.WillRetain() || this.WillQos() != QosAtMostOnce) {
return total, fmt.Errorf("connect/decodeMessage: Protocol violation: If the Will Flag (%t) is set to 0 the Will QoS (%d) and Will Retain (%t) fields MUST be set to zero", this.WillFlag(), this.WillQos(), this.WillRetain())
}
if this.UsernameFlag() && !this.PasswordFlag() {
return total, fmt.Errorf("connect/decodeMessage: Username flag is set but Password flag is not set")
}
if len(src[total:]) < 2 {
return 0, fmt.Errorf("connect/decodeMessage: Insufficient buffer size. Expecting %d, got %d.", 2, len(src[total:]))
}
this.keepAlive = binary.BigEndian.Uint16(src[total:])
total += 2
this.clientId, n, err = readLPBytes(src[total:])
total += n
if err != nil {
return total, err
}
// If the Client supplies a zero-byte ClientId, the Client MUST also set CleanSession to 1
if len(this.clientId) == 0 && !this.CleanSession() {
return total, ErrIdentifierRejected
}
// The ClientId must contain only characters 0-9, a-z, and A-Z
// We also support ClientId longer than 23 encoded bytes
// We do not support ClientId outside of the above characters
if len(this.clientId) > 0 && !this.validClientId(this.clientId) {
return total, ErrIdentifierRejected
}
if this.WillFlag() {
this.willTopic, n, err = readLPBytes(src[total:])
total += n
if err != nil {
return total, err
}
this.willMessage, n, err = readLPBytes(src[total:])
total += n
if err != nil {
return total, err
}
}
// According to the 3.1 spec, it's possible that the passwordFlag is set,
// but the password string is missing.
if this.UsernameFlag() && len(src[total:]) > 0 {
this.username, n, err = readLPBytes(src[total:])
total += n
if err != nil {
return total, err
}
}
// According to the 3.1 spec, it's possible that the passwordFlag is set,
// but the password string is missing.
if this.PasswordFlag() && len(src[total:]) > 0 {
this.password, n, err = readLPBytes(src[total:])
total += n
if err != nil {
return total, err
}
}
return total, nil
}
func (this *ConnectMessage) msglen() int {
total := 0
verstr, ok := SupportedVersions[this.version]
if !ok {
return total
}
// 2 bytes protocol name length
// n bytes protocol name
// 1 byte protocol version
// 1 byte connect flags
// 2 bytes keep alive timer
total += 2 + len(verstr) + 1 + 1 + 2
// Add the clientID length, 2 is the length prefix
total += 2 + len(this.clientId)
// Add the will topic and will message length, and the length prefixes
if this.WillFlag() {
total += 2 + len(this.willTopic) + 2 + len(this.willMessage)
}
// Add the username length
// According to the 3.1 spec, it's possible that the usernameFlag is set,
// but the user name string is missing.
if this.UsernameFlag() && len(this.username) > 0 {
total += 2 + len(this.username)
}
// Add the password length
// According to the 3.1 spec, it's possible that the passwordFlag is set,
// but the password string is missing.
if this.PasswordFlag() && len(this.password) > 0 {
total += 2 + len(this.password)
}
return total
}
// validClientId checks the client ID, which is a slice of bytes, to see if it's valid.
// Client ID is valid if it meets the requirement from the MQTT spec:
// The Server MUST allow ClientIds which are between 1 and 23 UTF-8 encoded bytes in length,
// and that contain only the characters
//
// "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
func (this *ConnectMessage) validClientId(cid []byte) bool {
// Fixed https://github.com/surgemq/surgemq/issues/4
//if len(cid) > 23 {
// return false
//}
if this.Version() == 0x3 {
return true
}
return clientIdRegexp.Match(cid)
}

View File

@@ -1,373 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestConnectMessageFields(t *testing.T) {
msg := NewConnectMessage()
err := msg.SetVersion(0x3)
require.NoError(t, err, "Error setting message version.")
require.Equal(t, 0x3, int(msg.Version()), "Incorrect version number")
err = msg.SetVersion(0x5)
require.Error(t, err)
msg.SetCleanSession(true)
require.True(t, msg.CleanSession(), "Error setting clean session flag.")
msg.SetCleanSession(false)
require.False(t, msg.CleanSession(), "Error setting clean session flag.")
msg.SetWillFlag(true)
require.True(t, msg.WillFlag(), "Error setting will flag.")
msg.SetWillFlag(false)
require.False(t, msg.WillFlag(), "Error setting will flag.")
msg.SetWillRetain(true)
require.True(t, msg.WillRetain(), "Error setting will retain.")
msg.SetWillRetain(false)
require.False(t, msg.WillRetain(), "Error setting will retain.")
msg.SetPasswordFlag(true)
require.True(t, msg.PasswordFlag(), "Error setting password flag.")
msg.SetPasswordFlag(false)
require.False(t, msg.PasswordFlag(), "Error setting password flag.")
msg.SetUsernameFlag(true)
require.True(t, msg.UsernameFlag(), "Error setting username flag.")
msg.SetUsernameFlag(false)
require.False(t, msg.UsernameFlag(), "Error setting username flag.")
msg.SetWillQos(1)
require.Equal(t, 1, int(msg.WillQos()), "Error setting will QoS.")
err = msg.SetWillQos(4)
require.Error(t, err)
err = msg.SetClientId([]byte("j0j0jfajf02j0asdjf"))
require.NoError(t, err, "Error setting client ID")
require.Equal(t, "j0j0jfajf02j0asdjf", string(msg.ClientId()), "Error setting client ID.")
err = msg.SetClientId([]byte("this is good for v3"))
require.NoError(t, err)
msg.SetVersion(0x4)
err = msg.SetClientId([]byte("this is no good for v4!"))
require.Error(t, err)
msg.SetVersion(0x3)
msg.SetWillTopic([]byte("willtopic"))
require.Equal(t, "willtopic", string(msg.WillTopic()), "Error setting will topic.")
require.True(t, msg.WillFlag(), "Error setting will flag.")
msg.SetWillTopic([]byte(""))
require.Equal(t, "", string(msg.WillTopic()), "Error setting will topic.")
require.False(t, msg.WillFlag(), "Error setting will flag.")
msg.SetWillMessage([]byte("this is a will message"))
require.Equal(t, "this is a will message", string(msg.WillMessage()), "Error setting will message.")
require.True(t, msg.WillFlag(), "Error setting will flag.")
msg.SetWillMessage([]byte(""))
require.Equal(t, "", string(msg.WillMessage()), "Error setting will topic.")
require.False(t, msg.WillFlag(), "Error setting will flag.")
msg.SetWillTopic([]byte("willtopic"))
msg.SetWillMessage([]byte("this is a will message"))
msg.SetWillTopic([]byte(""))
require.True(t, msg.WillFlag(), "Error setting will topic.")
msg.SetUsername([]byte("myname"))
require.Equal(t, "myname", string(msg.Username()), "Error setting will message.")
require.True(t, msg.UsernameFlag(), "Error setting will flag.")
msg.SetUsername([]byte(""))
require.Equal(t, "", string(msg.Username()), "Error setting will message.")
require.False(t, msg.UsernameFlag(), "Error setting will flag.")
msg.SetPassword([]byte("myname"))
require.Equal(t, "myname", string(msg.Password()), "Error setting will message.")
require.True(t, msg.PasswordFlag(), "Error setting will flag.")
msg.SetPassword([]byte(""))
require.Equal(t, "", string(msg.Password()), "Error setting will message.")
require.False(t, msg.PasswordFlag(), "Error setting will flag.")
}
func TestConnectMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(CONNECT << 4),
60,
0, // Length MSB (0)
4, // Length LSB (4)
'M', 'Q', 'T', 'T',
4, // Protocol level 4
206, // connect flags 11001110, will QoS = 01
0, // Keep Alive MSB (0)
10, // Keep Alive LSB (10)
0, // Client ID MSB (0)
7, // Client ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Will Topic MSB (0)
4, // Will Topic LSB (4)
'w', 'i', 'l', 'l',
0, // Will Message MSB (0)
12, // Will Message LSB (12)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
0, // Username ID MSB (0)
7, // Username ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Password ID MSB (0)
10, // Password ID LSB (10)
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
}
msg := NewConnectMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, 206, int(msg.connectFlags), "Incorrect flag value.")
require.Equal(t, 10, int(msg.KeepAlive()), "Incorrect KeepAlive value.")
require.Equal(t, "surgemq", string(msg.ClientId()), "Incorrect client ID value.")
require.Equal(t, "will", string(msg.WillTopic()), "Incorrect will topic value.")
require.Equal(t, "send me home", string(msg.WillMessage()), "Incorrect will message value.")
require.Equal(t, "surgemq", string(msg.Username()), "Incorrect username value.")
require.Equal(t, "verysecret", string(msg.Password()), "Incorrect password value.")
}
func TestConnectMessageDecode2(t *testing.T) {
// missing last byte 't'
msgBytes := []byte{
byte(CONNECT << 4),
60,
0, // Length MSB (0)
4, // Length LSB (4)
'M', 'Q', 'T', 'T',
4, // Protocol level 4
206, // connect flags 11001110, will QoS = 01
0, // Keep Alive MSB (0)
10, // Keep Alive LSB (10)
0, // Client ID MSB (0)
7, // Client ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Will Topic MSB (0)
4, // Will Topic LSB (4)
'w', 'i', 'l', 'l',
0, // Will Message MSB (0)
12, // Will Message LSB (12)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
0, // Username ID MSB (0)
7, // Username ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Password ID MSB (0)
10, // Password ID LSB (10)
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e',
}
msg := NewConnectMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestConnectMessageDecode3(t *testing.T) {
// extra bytes
msgBytes := []byte{
byte(CONNECT << 4),
60,
0, // Length MSB (0)
4, // Length LSB (4)
'M', 'Q', 'T', 'T',
4, // Protocol level 4
206, // connect flags 11001110, will QoS = 01
0, // Keep Alive MSB (0)
10, // Keep Alive LSB (10)
0, // Client ID MSB (0)
7, // Client ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Will Topic MSB (0)
4, // Will Topic LSB (4)
'w', 'i', 'l', 'l',
0, // Will Message MSB (0)
12, // Will Message LSB (12)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
0, // Username ID MSB (0)
7, // Username ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Password ID MSB (0)
10, // Password ID LSB (10)
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
'e', 'x', 't', 'r', 'a',
}
msg := NewConnectMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err)
require.Equal(t, 62, n)
}
func TestConnectMessageDecode4(t *testing.T) {
// missing client Id, clean session == 0
msgBytes := []byte{
byte(CONNECT << 4),
53,
0, // Length MSB (0)
4, // Length LSB (4)
'M', 'Q', 'T', 'T',
4, // Protocol level 4
204, // connect flags 11001110, will QoS = 01
0, // Keep Alive MSB (0)
10, // Keep Alive LSB (10)
0, // Client ID MSB (0)
0, // Client ID LSB (0)
0, // Will Topic MSB (0)
4, // Will Topic LSB (4)
'w', 'i', 'l', 'l',
0, // Will Message MSB (0)
12, // Will Message LSB (12)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
0, // Username ID MSB (0)
7, // Username ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Password ID MSB (0)
10, // Password ID LSB (10)
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
}
msg := NewConnectMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestConnectMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(CONNECT << 4),
60,
0, // Length MSB (0)
4, // Length LSB (4)
'M', 'Q', 'T', 'T',
4, // Protocol level 4
206, // connect flags 11001110, will QoS = 01
0, // Keep Alive MSB (0)
10, // Keep Alive LSB (10)
0, // Client ID MSB (0)
7, // Client ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Will Topic MSB (0)
4, // Will Topic LSB (4)
'w', 'i', 'l', 'l',
0, // Will Message MSB (0)
12, // Will Message LSB (12)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
0, // Username ID MSB (0)
7, // Username ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Password ID MSB (0)
10, // Password ID LSB (10)
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
}
msg := NewConnectMessage()
msg.SetWillQos(1)
msg.SetVersion(4)
msg.SetCleanSession(true)
msg.SetClientId([]byte("surgemq"))
msg.SetKeepAlive(10)
msg.SetWillTopic([]byte("will"))
msg.SetWillMessage([]byte("send me home"))
msg.SetUsername([]byte("surgemq"))
msg.SetPassword([]byte("verysecret"))
dst := make([]byte, 100)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestConnectDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(CONNECT << 4),
60,
0, // Length MSB (0)
4, // Length LSB (4)
'M', 'Q', 'T', 'T',
4, // Protocol level 4
206, // connect flags 11001110, will QoS = 01
0, // Keep Alive MSB (0)
10, // Keep Alive LSB (10)
0, // Client ID MSB (0)
7, // Client ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Will Topic MSB (0)
4, // Will Topic LSB (4)
'w', 'i', 'l', 'l',
0, // Will Message MSB (0)
12, // Will Message LSB (12)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
0, // Username ID MSB (0)
7, // Username ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Password ID MSB (0)
10, // Password ID LSB (10)
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
}
msg := NewConnectMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,49 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import "fmt"
// The DISCONNECT Packet is the final Control Packet sent from the Client to the Server.
// It indicates that the Client is disconnecting cleanly.
type DisconnectMessage struct {
header
}
var _ Message = (*DisconnectMessage)(nil)
// NewDisconnectMessage creates a new DISCONNECT message.
func NewDisconnectMessage() *DisconnectMessage {
msg := &DisconnectMessage{}
msg.SetType(DISCONNECT)
return msg
}
func (this *DisconnectMessage) Decode(src []byte) (int, error) {
return this.header.decode(src)
}
func (this *DisconnectMessage) Encode(dst []byte) (int, error) {
if !this.dirty {
if len(dst) < len(this.dbuf) {
return 0, fmt.Errorf("disconnect/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
}
return copy(dst, this.dbuf), nil
}
return this.header.encode(dst)
}

View File

@@ -1,78 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestDisconnectMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(DISCONNECT << 4),
0,
}
msg := NewDisconnectMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, DISCONNECT, msg.Type(), "Error decoding message.")
}
func TestDisconnectMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(DISCONNECT << 4),
0,
}
msg := NewDisconnectMessage()
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestDisconnectDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(DISCONNECT << 4),
0,
}
msg := NewDisconnectMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,141 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
Package message is an encoder/decoder library for MQTT 3.1 and 3.1.1 messages. You can
find the MQTT specs at the following locations:
3.1.1 - http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/
3.1 - http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html
From the spec:
MQTT is a Client Server publish/subscribe messaging transport protocol. It is
light weight, open, simple, and designed so as to be easy to implement. These
characteristics make it ideal for use in many situations, including constrained
environments such as for communication in Machine to Machine (M2M) and Internet
of Things (IoT) contexts where a small code footprint is required and/or network
bandwidth is at a premium.
The MQTT protocol works by exchanging a series of MQTT messages in a defined way.
The protocol runs over TCP/IP, or over other network protocols that provide
ordered, lossless, bi-directional connections.
There are two main items to take note in this package. The first is
type MessageType byte
MessageType is the type representing the MQTT packet types. In the MQTT spec, MQTT
control packet type is represented as a 4-bit unsigned value. MessageType receives
several methods that returns string representations of the names and descriptions.
Also, one of the methods is New(). It returns a new Message object based on the mtype
parameter. For example:
m, err := CONNECT.New()
msg := m.(*ConnectMessage)
This would return a PublishMessage struct, but mapped to the Message interface. You can
then type assert it back to a *PublishMessage. Another way to create a new
PublishMessage is to call
msg := NewConnectMessage()
Every message type has a New function that returns a new message. The list of available
message types are defined as constants below.
As you may have noticed, the second important item is the Message interface. It defines
several methods that are common to all messages, including Name(), Desc(), and Type().
Most importantly, it also defines the Encode() and Decode() methods.
Encode() (io.Reader, int, error)
Decode(io.Reader) (int, error)
Encode returns an io.Reader in which the encoded bytes can be read. The second return
value is the number of bytes encoded, so the caller knows how many bytes there will be.
If Encode returns an error, then the first two return values should be considered invalid.
Any changes to the message after Encode() is called will invalidate the io.Reader.
Decode reads from the io.Reader parameter until a full message is decoded, or when io.Reader
returns EOF or error. The first return value is the number of bytes read from io.Reader.
The second is error if Decode encounters any problems.
With these in mind, we can now do:
// Create a new CONNECT message
msg := NewConnectMessage()
// Set the appropriate parameters
msg.SetWillQos(1)
msg.SetVersion(4)
msg.SetCleanSession(true)
msg.SetClientId([]byte("surgemq"))
msg.SetKeepAlive(10)
msg.SetWillTopic([]byte("will"))
msg.SetWillMessage([]byte("send me home"))
msg.SetUsername([]byte("surgemq"))
msg.SetPassword([]byte("verysecret"))
// Encode the message and get the io.Reader
r, n, err := msg.Encode()
if err == nil {
return err
}
// Write n bytes into the connection
m, err := io.CopyN(conn, r, int64(n))
if err != nil {
return err
}
fmt.Printf("Sent %d bytes of %s message", m, msg.Name())
To receive a CONNECT message from a connection, we can do:
// Create a new CONNECT message
msg := NewConnectMessage()
// Decode the message by reading from conn
n, err := msg.Decode(conn)
If you don't know what type of message is coming down the pipe, you can do something like this:
// Create a buffered IO reader for the connection
br := bufio.NewReader(conn)
// Peek at the first byte, which contains the message type
b, err := br.Peek(1)
if err != nil {
return err
}
// Extract the type from the first byte
t := MessageType(b[0] >> 4)
// Create a new message
msg, err := t.New()
if err != nil {
return err
}
// Decode it from the bufio.Reader
n, err := msg.Decode(br)
if err != nil {
return err
}
*/
package message

View File

@@ -1,248 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"encoding/binary"
"fmt"
)
var (
gPacketId uint64 = 0
)
// Fixed header
// - 1 byte for control packet type (bits 7-4) and flags (bits 3-0)
// - up to 4 byte for remaining length
type header struct {
// Header fields
//mtype MessageType
//flags byte
remlen int32
// mtypeflags is the first byte of the buffer, 4 bits for mtype, 4 bits for flags
mtypeflags []byte
// Some messages need packet ID, 2 byte uint16
packetId []byte
// Points to the decoding buffer
dbuf []byte
// Whether the message has changed since last decode
dirty bool
}
// String returns a string representation of the message.
func (this header) String() string {
return fmt.Sprintf("Type=%q, Flags=%08b, Remaining Length=%d", this.Type().Name(), this.Flags(), this.remlen)
}
// Name returns a string representation of the message type. Examples include
// "PUBLISH", "SUBSCRIBE", and others. This is statically defined for each of
// the message types and cannot be changed.
func (this *header) Name() string {
return this.Type().Name()
}
// Desc returns a string description of the message type. For example, a
// CONNECT message would return "Client request to connect to Server." These
// descriptions are statically defined (copied from the MQTT spec) and cannot
// be changed.
func (this *header) Desc() string {
return this.Type().Desc()
}
// Type returns the MessageType of the Message. The retured value should be one
// of the constants defined for MessageType.
func (this *header) Type() MessageType {
//return this.mtype
if len(this.mtypeflags) != 1 {
this.mtypeflags = make([]byte, 1)
this.dirty = true
}
return MessageType(this.mtypeflags[0] >> 4)
}
// SetType sets the message type of this message. It also correctly sets the
// default flags for the message type. It returns an error if the type is invalid.
func (this *header) SetType(mtype MessageType) error {
if !mtype.Valid() {
return fmt.Errorf("header/SetType: Invalid control packet type %d", mtype)
}
// Notice we don't set the message to be dirty when we are not allocating a new
// buffer. In this case, it means the buffer is probably a sub-slice of another
// slice. If that's the case, then during encoding we would have copied the whole
// backing buffer anyway.
if len(this.mtypeflags) != 1 {
this.mtypeflags = make([]byte, 1)
this.dirty = true
}
this.mtypeflags[0] = byte(mtype)<<4 | (mtype.DefaultFlags() & 0xf)
return nil
}
// Flags returns the fixed header flags for this message.
func (this *header) Flags() byte {
//return this.flags
return this.mtypeflags[0] & 0x0f
}
// RemainingLength returns the length of the non-fixed-header part of the message.
func (this *header) RemainingLength() int32 {
return this.remlen
}
// SetRemainingLength sets the length of the non-fixed-header part of the message.
// It returns error if the length is greater than 268435455, which is the max
// message length as defined by the MQTT spec.
func (this *header) SetRemainingLength(remlen int32) error {
if remlen > maxRemainingLength || remlen < 0 {
return fmt.Errorf("header/SetLength: Remaining length (%d) out of bound (max %d, min 0)", remlen, maxRemainingLength)
}
this.remlen = remlen
this.dirty = true
return nil
}
func (this *header) Len() int {
return this.msglen()
}
// PacketId returns the ID of the packet.
func (this *header) PacketId() uint16 {
if len(this.packetId) == 2 {
return binary.BigEndian.Uint16(this.packetId)
}
return 0
}
// SetPacketId sets the ID of the packet.
func (this *header) SetPacketId(v uint16) {
// If setting to 0, nothing to do, move on
if v == 0 {
return
}
// If packetId buffer is not 2 bytes (uint16), then we allocate a new one and
// make dirty. Then we encode the packet ID into the buffer.
if len(this.packetId) != 2 {
this.packetId = make([]byte, 2)
this.dirty = true
}
// Notice we don't set the message to be dirty when we are not allocating a new
// buffer. In this case, it means the buffer is probably a sub-slice of another
// slice. If that's the case, then during encoding we would have copied the whole
// backing buffer anyway.
binary.BigEndian.PutUint16(this.packetId, v)
}
func (this *header) encode(dst []byte) (int, error) {
ml := this.msglen()
if len(dst) < ml {
return 0, fmt.Errorf("header/Encode: Insufficient buffer size. Expecting %d, got %d.", ml, len(dst))
}
total := 0
if this.remlen > maxRemainingLength || this.remlen < 0 {
return total, fmt.Errorf("header/Encode: Remaining length (%d) out of bound (max %d, min 0)", this.remlen, maxRemainingLength)
}
if !this.Type().Valid() {
return total, fmt.Errorf("header/Encode: Invalid message type %d", this.Type())
}
dst[total] = this.mtypeflags[0]
total += 1
n := binary.PutUvarint(dst[total:], uint64(this.remlen))
total += n
return total, nil
}
// Decode reads from the io.Reader parameter until a full message is decoded, or
// when io.Reader returns EOF or error. The first return value is the number of
// bytes read from io.Reader. The second is error if Decode encounters any problems.
func (this *header) decode(src []byte) (int, error) {
total := 0
this.dbuf = src
mtype := this.Type()
//mtype := MessageType(0)
this.mtypeflags = src[total : total+1]
//mtype := MessageType(src[total] >> 4)
if !this.Type().Valid() {
return total, fmt.Errorf("header/Decode: Invalid message type %d.", mtype)
}
if mtype != this.Type() {
return total, fmt.Errorf("header/Decode: Invalid message type %d. Expecting %d.", this.Type(), mtype)
}
//this.flags = src[total] & 0x0f
if this.Type() != PUBLISH && this.Flags() != this.Type().DefaultFlags() {
return total, fmt.Errorf("header/Decode: Invalid message (%d) flags. Expecting %d, got %d", this.Type(), this.Type().DefaultFlags(), this.Flags())
}
if this.Type() == PUBLISH && !ValidQos((this.Flags()>>1)&0x3) {
return total, fmt.Errorf("header/Decode: Invalid QoS (%d) for PUBLISH message.", (this.Flags()>>1)&0x3)
}
total++
remlen, m := binary.Uvarint(src[total:])
total += m
this.remlen = int32(remlen)
if this.remlen > maxRemainingLength || remlen < 0 {
return total, fmt.Errorf("header/Decode: Remaining length (%d) out of bound (max %d, min 0)", this.remlen, maxRemainingLength)
}
if int(this.remlen) > len(src[total:]) {
return total, fmt.Errorf("header/Decode: Remaining length (%d) is greater than remaining buffer (%d)", this.remlen, len(src[total:]))
}
return total, nil
}
func (this *header) msglen() int {
// message type and flag byte
total := 1
if this.remlen <= 127 {
total += 1
} else if this.remlen <= 16383 {
total += 2
} else if this.remlen <= 2097151 {
total += 3
} else {
total += 4
}
return total
}

View File

@@ -1,182 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMessageHeaderFields(t *testing.T) {
header := &header{}
header.SetRemainingLength(33)
require.Equal(t, int32(33), header.RemainingLength())
err := header.SetRemainingLength(268435456)
require.Error(t, err)
err = header.SetRemainingLength(-1)
require.Error(t, err)
err = header.SetType(RESERVED)
require.Error(t, err)
err = header.SetType(PUBREL)
require.NoError(t, err)
require.Equal(t, PUBREL, header.Type())
require.Equal(t, "PUBREL", header.Name())
require.Equal(t, 2, int(header.Flags()))
}
// Not enough bytes
func TestMessageHeaderDecode(t *testing.T) {
buf := []byte{0x6f, 193, 2}
header := &header{}
_, err := header.decode(buf)
require.Error(t, err)
}
// Remaining length too big
func TestMessageHeaderDecode2(t *testing.T) {
buf := []byte{0x62, 0xff, 0xff, 0xff, 0xff}
header := &header{}
_, err := header.decode(buf)
require.Error(t, err)
}
func TestMessageHeaderDecode3(t *testing.T) {
buf := []byte{0x62, 0xff}
header := &header{}
_, err := header.decode(buf)
require.Error(t, err)
}
func TestMessageHeaderDecode4(t *testing.T) {
buf := []byte{0x62, 0xff, 0xff, 0xff, 0x7f}
header := &header{
mtypeflags: []byte{6<<4 | 2},
//mtype: 6,
//flags: 2,
}
n, err := header.decode(buf)
require.Error(t, err)
require.Equal(t, 5, n)
require.Equal(t, maxRemainingLength, header.RemainingLength())
}
func TestMessageHeaderDecode5(t *testing.T) {
buf := []byte{0x62, 0xff, 0x7f}
header := &header{
mtypeflags: []byte{6<<4 | 2},
//mtype: 6,
//flags: 2,
}
n, err := header.decode(buf)
require.Error(t, err)
require.Equal(t, 3, n)
}
func TestMessageHeaderEncode1(t *testing.T) {
header := &header{}
headerBytes := []byte{0x62, 193, 2}
err := header.SetType(PUBREL)
require.NoError(t, err)
err = header.SetRemainingLength(321)
require.NoError(t, err)
buf := make([]byte, 3)
n, err := header.encode(buf)
require.NoError(t, err)
require.Equal(t, 3, n)
require.Equal(t, headerBytes, buf)
}
func TestMessageHeaderEncode2(t *testing.T) {
header := &header{}
err := header.SetType(PUBREL)
require.NoError(t, err)
header.remlen = 268435456
buf := make([]byte, 5)
_, err = header.encode(buf)
require.Error(t, err)
}
func TestMessageHeaderEncode3(t *testing.T) {
header := &header{}
headerBytes := []byte{0x62, 0xff, 0xff, 0xff, 0x7f}
err := header.SetType(PUBREL)
require.NoError(t, err)
err = header.SetRemainingLength(maxRemainingLength)
require.NoError(t, err)
buf := make([]byte, 5)
n, err := header.encode(buf)
require.NoError(t, err)
require.Equal(t, 5, n)
require.Equal(t, headerBytes, buf)
}
func TestMessageHeaderEncode4(t *testing.T) {
header := &header{
mtypeflags: []byte{byte(RESERVED2) << 4},
//mtype: 6,
//flags: 2,
}
buf := make([]byte, 5)
_, err := header.encode(buf)
require.Error(t, err)
}
/*
// This test is to ensure that an empty message is at least 2 bytes long
func TestMessageHeaderEncode5(t *testing.T) {
msg := NewPingreqMessage()
dst, n, err := msg.encode()
if err != nil {
t.Errorf("Error encoding PINGREQ message: %v", err)
} else if n != 2 {
t.Errorf("Incorrect result. Expecting length of 2 bytes, got %d.", dst.(*bytes.Buffer).Len())
}
}
*/

View File

@@ -1,410 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"bytes"
"encoding/binary"
"fmt"
)
const (
maxLPString uint16 = 65535
maxFixedHeaderLength int = 5
maxRemainingLength int32 = 268435455 // bytes, or 256 MB
)
const (
// QoS 0: At most once delivery
// The message is delivered according to the capabilities of the underlying network.
// No response is sent by the receiver and no retry is performed by the sender. The
// message arrives at the receiver either once or not at all.
QosAtMostOnce byte = iota
// QoS 1: At least once delivery
// This quality of service ensures that the message arrives at the receiver at least once.
// A QoS 1 PUBLISH Packet has a Packet Identifier in its variable header and is acknowledged
// by a PUBACK Packet. Section 2.3.1 provides more information about Packet Identifiers.
QosAtLeastOnce
// QoS 2: Exactly once delivery
// This is the highest quality of service, for use when neither loss nor duplication of
// messages are acceptable. There is an increased overhead associated with this quality of
// service.
QosExactlyOnce
// QosFailure is a return value for a subscription if there's a problem while subscribing
// to a specific topic.
QosFailure = 0x80
)
// SupportedVersions is a map of the version number (0x3 or 0x4) to the version string,
// "MQIsdp" for 0x3, and "MQTT" for 0x4.
var SupportedVersions map[byte]string = map[byte]string{
0x3: "MQIsdp",
0x4: "MQTT",
}
// MessageType is the type representing the MQTT packet types. In the MQTT spec,
// MQTT control packet type is represented as a 4-bit unsigned value.
type MessageType byte
// Message is an interface defined for all MQTT message types.
type Message interface {
// Name returns a string representation of the message type. Examples include
// "PUBLISH", "SUBSCRIBE", and others. This is statically defined for each of
// the message types and cannot be changed.
Name() string
// Desc returns a string description of the message type. For example, a
// CONNECT message would return "Client request to connect to Server." These
// descriptions are statically defined (copied from the MQTT spec) and cannot
// be changed.
Desc() string
// Type returns the MessageType of the Message. The retured value should be one
// of the constants defined for MessageType.
Type() MessageType
// PacketId returns the packet ID of the Message. The retured value is 0 if
// there's no packet ID for this message type. Otherwise non-0.
PacketId() uint16
// Encode writes the message bytes into the byte array from the argument. It
// returns the number of bytes encoded and whether there's any errors along
// the way. If there's any errors, then the byte slice and count should be
// considered invalid.
Encode([]byte) (int, error)
// Decode reads the bytes in the byte slice from the argument. It returns the
// total number of bytes decoded, and whether there's any errors during the
// process. The byte slice MUST NOT be modified during the duration of this
// message being available since the byte slice is internally stored for
// references.
Decode([]byte) (int, error)
Len() int
}
const (
// RESERVED is a reserved value and should be considered an invalid message type
RESERVED MessageType = iota
// CONNECT: Client to Server. Client request to connect to Server.
CONNECT
// CONNACK: Server to Client. Connect acknowledgement.
CONNACK
// PUBLISH: Client to Server, or Server to Client. Publish message.
PUBLISH
// PUBACK: Client to Server, or Server to Client. Publish acknowledgment for
// QoS 1 messages.
PUBACK
// PUBACK: Client to Server, or Server to Client. Publish received for QoS 2 messages.
// Assured delivery part 1.
PUBREC
// PUBREL: Client to Server, or Server to Client. Publish release for QoS 2 messages.
// Assured delivery part 1.
PUBREL
// PUBCOMP: Client to Server, or Server to Client. Publish complete for QoS 2 messages.
// Assured delivery part 3.
PUBCOMP
// SUBSCRIBE: Client to Server. Client subscribe request.
SUBSCRIBE
// SUBACK: Server to Client. Subscribe acknowledgement.
SUBACK
// UNSUBSCRIBE: Client to Server. Unsubscribe request.
UNSUBSCRIBE
// UNSUBACK: Server to Client. Unsubscribe acknowlegment.
UNSUBACK
// PINGREQ: Client to Server. PING request.
PINGREQ
// PINGRESP: Server to Client. PING response.
PINGRESP
// DISCONNECT: Client to Server. Client is disconnecting.
DISCONNECT
// RESERVED2 is a reserved value and should be considered an invalid message type.
RESERVED2
)
func (this MessageType) String() string {
return this.Name()
}
// Name returns the name of the message type. It should correspond to one of the
// constant values defined for MessageType. It is statically defined and cannot
// be changed.
func (this MessageType) Name() string {
switch this {
case RESERVED:
return "RESERVED"
case CONNECT:
return "CONNECT"
case CONNACK:
return "CONNACK"
case PUBLISH:
return "PUBLISH"
case PUBACK:
return "PUBACK"
case PUBREC:
return "PUBREC"
case PUBREL:
return "PUBREL"
case PUBCOMP:
return "PUBCOMP"
case SUBSCRIBE:
return "SUBSCRIBE"
case SUBACK:
return "SUBACK"
case UNSUBSCRIBE:
return "UNSUBSCRIBE"
case UNSUBACK:
return "UNSUBACK"
case PINGREQ:
return "PINGREQ"
case PINGRESP:
return "PINGRESP"
case DISCONNECT:
return "DISCONNECT"
case RESERVED2:
return "RESERVED2"
}
return "UNKNOWN"
}
// Desc returns the description of the message type. It is statically defined (copied
// from MQTT spec) and cannot be changed.
func (this MessageType) Desc() string {
switch this {
case RESERVED:
return "Reserved"
case CONNECT:
return "Client request to connect to Server"
case CONNACK:
return "Connect acknowledgement"
case PUBLISH:
return "Publish message"
case PUBACK:
return "Publish acknowledgement"
case PUBREC:
return "Publish received (assured delivery part 1)"
case PUBREL:
return "Publish release (assured delivery part 2)"
case PUBCOMP:
return "Publish complete (assured delivery part 3)"
case SUBSCRIBE:
return "Client subscribe request"
case SUBACK:
return "Subscribe acknowledgement"
case UNSUBSCRIBE:
return "Unsubscribe request"
case UNSUBACK:
return "Unsubscribe acknowledgement"
case PINGREQ:
return "PING request"
case PINGRESP:
return "PING response"
case DISCONNECT:
return "Client is disconnecting"
case RESERVED2:
return "Reserved"
}
return "UNKNOWN"
}
// DefaultFlags returns the default flag values for the message type, as defined by
// the MQTT spec.
func (this MessageType) DefaultFlags() byte {
switch this {
case RESERVED:
return 0
case CONNECT:
return 0
case CONNACK:
return 0
case PUBLISH:
return 0
case PUBACK:
return 0
case PUBREC:
return 0
case PUBREL:
return 2
case PUBCOMP:
return 0
case SUBSCRIBE:
return 2
case SUBACK:
return 0
case UNSUBSCRIBE:
return 2
case UNSUBACK:
return 0
case PINGREQ:
return 0
case PINGRESP:
return 0
case DISCONNECT:
return 0
case RESERVED2:
return 0
}
return 0
}
// New creates a new message based on the message type. It is a shortcut to call
// one of the New*Message functions. If an error is returned then the message type
// is invalid.
func (this MessageType) New() (Message, error) {
switch this {
case CONNECT:
return NewConnectMessage(), nil
case CONNACK:
return NewConnackMessage(), nil
case PUBLISH:
return NewPublishMessage(), nil
case PUBACK:
return NewPubackMessage(), nil
case PUBREC:
return NewPubrecMessage(), nil
case PUBREL:
return NewPubrelMessage(), nil
case PUBCOMP:
return NewPubcompMessage(), nil
case SUBSCRIBE:
return NewSubscribeMessage(), nil
case SUBACK:
return NewSubackMessage(), nil
case UNSUBSCRIBE:
return NewUnsubscribeMessage(), nil
case UNSUBACK:
return NewUnsubackMessage(), nil
case PINGREQ:
return NewPingreqMessage(), nil
case PINGRESP:
return NewPingrespMessage(), nil
case DISCONNECT:
return NewDisconnectMessage(), nil
}
return nil, fmt.Errorf("msgtype/NewMessage: Invalid message type %d", this)
}
// Valid returns a boolean indicating whether the message type is valid or not.
func (this MessageType) Valid() bool {
return this > RESERVED && this < RESERVED2
}
// ValidTopic checks the topic, which is a slice of bytes, to see if it's valid. Topic is
// considered valid if it's longer than 0 bytes, and doesn't contain any wildcard characters
// such as + and #.
func ValidTopic(topic []byte) bool {
return len(topic) > 0 && bytes.IndexByte(topic, '#') == -1 && bytes.IndexByte(topic, '+') == -1
}
// ValidQos checks the QoS value to see if it's valid. Valid QoS are QosAtMostOnce,
// QosAtLeastonce, and QosExactlyOnce.
func ValidQos(qos byte) bool {
return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce
}
// ValidVersion checks to see if the version is valid. Current supported versions include 0x3 and 0x4.
func ValidVersion(v byte) bool {
_, ok := SupportedVersions[v]
return ok
}
// ValidConnackError checks to see if the error is a Connack Error or not
func ValidConnackError(err error) bool {
return err == ErrInvalidProtocolVersion || err == ErrIdentifierRejected ||
err == ErrServerUnavailable || err == ErrBadUsernameOrPassword || err == ErrNotAuthorized
}
func ValidConnackErrorEx(err error) (bool, ConnackCode) {
if err == ErrInvalidProtocolVersion {
return true, ErrInvalidProtocolVersion
}
if err == ErrIdentifierRejected {
return true, ErrIdentifierRejected
}
if err == ErrServerUnavailable {
return true, ErrServerUnavailable
}
if err == ErrBadUsernameOrPassword {
return true, ErrBadUsernameOrPassword
}
if err == ErrNotAuthorized {
return true, ErrNotAuthorized
}
return false, ConnectionAccepted
}
// Read length prefixed bytes
func readLPBytes(buf []byte) ([]byte, int, error) {
if len(buf) < 2 {
return nil, 0, fmt.Errorf("utils/readLPBytes: Insufficient buffer size. Expecting %d, got %d.", 2, len(buf))
}
n, total := 0, 0
n = int(binary.BigEndian.Uint16(buf))
total += 2
if len(buf) < n {
return nil, total, fmt.Errorf("utils/readLPBytes: Insufficient buffer size. Expecting %d, got %d.", n, len(buf))
}
total += n
return buf[2:total], total, nil
}
// Write length prefixed bytes
func writeLPBytes(buf []byte, b []byte) (int, error) {
total, n := 0, len(b)
if n > int(maxLPString) {
return 0, fmt.Errorf("utils/writeLPBytes: Length (%d) greater than %d bytes.", n, maxLPString)
}
if len(buf) < 2+n {
return 0, fmt.Errorf("utils/writeLPBytes: Insufficient buffer size. Expecting %d, got %d.", 2+n, len(buf))
}
binary.BigEndian.PutUint16(buf, uint16(n))
total += 2
copy(buf[total:], b)
total += n
return total, nil
}

View File

@@ -1,178 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
var (
lpstrings []string = []string{
"this is a test",
"hope it succeeds",
"but just in case",
"send me your millions",
"",
}
lpstringBytes []byte = []byte{
0x0, 0xe, 't', 'h', 'i', 's', ' ', 'i', 's', ' ', 'a', ' ', 't', 'e', 's', 't',
0x0, 0x10, 'h', 'o', 'p', 'e', ' ', 'i', 't', ' ', 's', 'u', 'c', 'c', 'e', 'e', 'd', 's',
0x0, 0x10, 'b', 'u', 't', ' ', 'j', 'u', 's', 't', ' ', 'i', 'n', ' ', 'c', 'a', 's', 'e',
0x0, 0x15, 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'y', 'o', 'u', 'r', ' ', 'm', 'i', 'l', 'l', 'i', 'o', 'n', 's',
0x0, 0x0,
}
msgBytes []byte = []byte{
byte(CONNECT << 4),
60,
0, // Length MSB (0)
4, // Length LSB (4)
'M', 'Q', 'T', 'T',
4, // Protocol level 4
206, // connect flags 11001110, will QoS = 01
0, // Keep Alive MSB (0)
10, // Keep Alive LSB (10)
0, // Client ID MSB (0)
7, // Client ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Will Topic MSB (0)
4, // Will Topic LSB (4)
'w', 'i', 'l', 'l',
0, // Will Message MSB (0)
12, // Will Message LSB (12)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
0, // Username ID MSB (0)
7, // Username ID LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // Password ID MSB (0)
10, // Password ID LSB (10)
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
}
)
func TestReadLPBytes(t *testing.T) {
total := 0
for _, str := range lpstrings {
b, n, err := readLPBytes(lpstringBytes[total:])
require.NoError(t, err)
require.Equal(t, str, string(b))
require.Equal(t, len(str)+2, n)
total += n
}
}
func TestWriteLPBytes(t *testing.T) {
total := 0
buf := make([]byte, 1000)
for _, str := range lpstrings {
n, err := writeLPBytes(buf[total:], []byte(str))
require.NoError(t, err)
require.Equal(t, 2+len(str), n)
total += n
}
require.Equal(t, lpstringBytes, buf[:total])
}
func TestMessageTypes(t *testing.T) {
if CONNECT != 1 ||
CONNACK != 2 ||
PUBLISH != 3 ||
PUBACK != 4 ||
PUBREC != 5 ||
PUBREL != 6 ||
PUBCOMP != 7 ||
SUBSCRIBE != 8 ||
SUBACK != 9 ||
UNSUBSCRIBE != 10 ||
UNSUBACK != 11 ||
PINGREQ != 12 ||
PINGRESP != 13 ||
DISCONNECT != 14 {
t.Errorf("Message types have invalid code")
}
}
func TestQosCodes(t *testing.T) {
if QosAtMostOnce != 0 || QosAtLeastOnce != 1 || QosExactlyOnce != 2 {
t.Errorf("QOS codes invalid")
}
}
func TestConnackReturnCodes(t *testing.T) {
require.Equal(t, ErrInvalidProtocolVersion.Error(), ConnackCode(1).Error(), "Incorrect ConnackCode error value.")
require.Equal(t, ErrIdentifierRejected.Error(), ConnackCode(2).Error(), "Incorrect ConnackCode error value.")
require.Equal(t, ErrServerUnavailable.Error(), ConnackCode(3).Error(), "Incorrect ConnackCode error value.")
require.Equal(t, ErrBadUsernameOrPassword.Error(), ConnackCode(4).Error(), "Incorrect ConnackCode error value.")
require.Equal(t, ErrNotAuthorized.Error(), ConnackCode(5).Error(), "Incorrect ConnackCode error value.")
}
func TestFixedHeaderFlags(t *testing.T) {
type detail struct {
name string
flags byte
}
details := map[MessageType]detail{
RESERVED: detail{"RESERVED", 0},
CONNECT: detail{"CONNECT", 0},
CONNACK: detail{"CONNACK", 0},
PUBLISH: detail{"PUBLISH", 0},
PUBACK: detail{"PUBACK", 0},
PUBREC: detail{"PUBREC", 0},
PUBREL: detail{"PUBREL", 2},
PUBCOMP: detail{"PUBCOMP", 0},
SUBSCRIBE: detail{"SUBSCRIBE", 2},
SUBACK: detail{"SUBACK", 0},
UNSUBSCRIBE: detail{"UNSUBSCRIBE", 2},
UNSUBACK: detail{"UNSUBACK", 0},
PINGREQ: detail{"PINGREQ", 0},
PINGRESP: detail{"PINGRESP", 0},
DISCONNECT: detail{"DISCONNECT", 0},
RESERVED2: detail{"RESERVED2", 0},
}
for m, d := range details {
if m.Name() != d.name {
t.Errorf("Name mismatch. Expecting %s, got %s.", d.name, m.Name())
}
if m.DefaultFlags() != d.flags {
t.Errorf("Flag mismatch for %s. Expecting %d, got %d.", m.Name(), d.flags, m.DefaultFlags())
}
}
}
func TestSupportedVersions(t *testing.T) {
for k, v := range SupportedVersions {
if k == 0x03 && v != "MQIsdp" {
t.Errorf("Protocol version and name mismatch. Expect %s, got %s.", "MQIsdp", v)
}
}
}

View File

@@ -1,135 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestPingreqMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(PINGREQ << 4),
0,
}
msg := NewPingreqMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, PINGREQ, msg.Type(), "Error decoding message.")
}
func TestPingreqMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(PINGREQ << 4),
0,
}
msg := NewPingreqMessage()
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
func TestPingrespMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(PINGRESP << 4),
0,
}
msg := NewPingrespMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, PINGRESP, msg.Type(), "Error decoding message.")
}
func TestPingrespMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(PINGRESP << 4),
0,
}
msg := NewPingrespMessage()
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestPingreqDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(PINGREQ << 4),
0,
}
msg := NewPingreqMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestPingrespDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(PINGRESP << 4),
0,
}
msg := NewPingrespMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,34 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
// The PINGREQ Packet is sent from a Client to the Server. It can be used to:
// 1. Indicate to the Server that the Client is alive in the absence of any other
// Control Packets being sent from the Client to the Server.
// 2. Request that the Server responds to confirm that it is alive.
// 3. Exercise the network to indicate that the Network Connection is active.
type PingreqMessage struct {
DisconnectMessage
}
var _ Message = (*PingreqMessage)(nil)
// NewPingreqMessage creates a new PINGREQ message.
func NewPingreqMessage() *PingreqMessage {
msg := &PingreqMessage{}
msg.SetType(PINGREQ)
return msg
}

View File

@@ -1,31 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
// A PINGRESP Packet is sent by the Server to the Client in response to a PINGREQ
// Packet. It indicates that the Server is alive.
type PingrespMessage struct {
DisconnectMessage
}
var _ Message = (*PingrespMessage)(nil)
// NewPingrespMessage creates a new PINGRESP message.
func NewPingrespMessage() *PingrespMessage {
msg := &PingrespMessage{}
msg.SetType(PINGRESP)
return msg
}

View File

@@ -1,109 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import "fmt"
// A PUBACK Packet is the response to a PUBLISH Packet with QoS level 1.
type PubackMessage struct {
header
}
var _ Message = (*PubackMessage)(nil)
// NewPubackMessage creates a new PUBACK message.
func NewPubackMessage() *PubackMessage {
msg := &PubackMessage{}
msg.SetType(PUBACK)
return msg
}
func (this PubackMessage) String() string {
return fmt.Sprintf("%s, Packet ID=%d", this.header, this.packetId)
}
func (this *PubackMessage) Len() int {
if !this.dirty {
return len(this.dbuf)
}
ml := this.msglen()
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0
}
return this.header.msglen() + ml
}
func (this *PubackMessage) Decode(src []byte) (int, error) {
total := 0
n, err := this.header.decode(src[total:])
total += n
if err != nil {
return total, err
}
//this.packetId = binary.BigEndian.Uint16(src[total:])
this.packetId = src[total : total+2]
total += 2
this.dirty = false
return total, nil
}
func (this *PubackMessage) Encode(dst []byte) (int, error) {
if !this.dirty {
if len(dst) < len(this.dbuf) {
return 0, fmt.Errorf("puback/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
}
return copy(dst, this.dbuf), nil
}
hl := this.header.msglen()
ml := this.msglen()
if len(dst) < hl+ml {
return 0, fmt.Errorf("puback/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
}
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0, err
}
total := 0
n, err := this.header.encode(dst[total:])
total += n
if err != nil {
return total, err
}
if copy(dst[total:total+2], this.packetId) != 2 {
dst[total], dst[total+1] = 0, 0
}
total += 2
return total, nil
}
func (this *PubackMessage) msglen() int {
// packet ID
return 2
}

View File

@@ -1,108 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestPubackMessageFields(t *testing.T) {
msg := NewPubackMessage()
msg.SetPacketId(100)
require.Equal(t, 100, int(msg.PacketId()))
}
func TestPubackMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(PUBACK << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubackMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, PUBACK, msg.Type(), "Error decoding message.")
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
}
// test insufficient bytes
func TestPubackMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(PUBACK << 4),
2,
7, // packet ID LSB (7)
}
msg := NewPubackMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestPubackMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(PUBACK << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubackMessage()
msg.SetPacketId(7)
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestPubackDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(PUBACK << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubackMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,31 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
// The PUBCOMP Packet is the response to a PUBREL Packet. It is the fourth and
// final packet of the QoS 2 protocol exchange.
type PubcompMessage struct {
PubackMessage
}
var _ Message = (*PubcompMessage)(nil)
// NewPubcompMessage creates a new PUBCOMP message.
func NewPubcompMessage() *PubcompMessage {
msg := &PubcompMessage{}
msg.SetType(PUBCOMP)
return msg
}

View File

@@ -1,108 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestPubcompMessageFields(t *testing.T) {
msg := NewPubcompMessage()
msg.SetPacketId(100)
require.Equal(t, 100, int(msg.PacketId()))
}
func TestPubcompMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(PUBCOMP << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubcompMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, PUBCOMP, msg.Type(), "Error decoding message.")
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
}
// test insufficient bytes
func TestPubcompMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(PUBCOMP << 4),
2,
7, // packet ID LSB (7)
}
msg := NewPubcompMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestPubcompMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(PUBCOMP << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubcompMessage()
msg.SetPacketId(7)
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestPubcompDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(PUBCOMP << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubcompMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,250 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"fmt"
"sync/atomic"
)
// A PUBLISH Control Packet is sent from a Client to a Server or from Server to a Client
// to transport an Application Message.
type PublishMessage struct {
header
topic []byte
payload []byte
}
var _ Message = (*PublishMessage)(nil)
// NewPublishMessage creates a new PUBLISH message.
func NewPublishMessage() *PublishMessage {
msg := &PublishMessage{}
msg.SetType(PUBLISH)
return msg
}
func (this PublishMessage) String() string {
return fmt.Sprintf("%s, Topic=%q, Packet ID=%d, QoS=%d, Retained=%t, Dup=%t, Payload=%v",
this.header, this.topic, this.packetId, this.QoS(), this.Retain(), this.Dup(), this.payload)
}
// Dup returns the value specifying the duplicate delivery of a PUBLISH Control Packet.
// If the DUP flag is set to 0, it indicates that this is the first occasion that the
// Client or Server has attempted to send this MQTT PUBLISH Packet. If the DUP flag is
// set to 1, it indicates that this might be re-delivery of an earlier attempt to send
// the Packet.
func (this *PublishMessage) Dup() bool {
return ((this.Flags() >> 3) & 0x1) == 1
}
// SetDup sets the value specifying the duplicate delivery of a PUBLISH Control Packet.
func (this *PublishMessage) SetDup(v bool) {
if v {
this.mtypeflags[0] |= 0x8 // 00001000
} else {
this.mtypeflags[0] &= 247 // 11110111
}
}
// Retain returns the value of the RETAIN flag. This flag is only used on the PUBLISH
// Packet. If the RETAIN flag is set to 1, in a PUBLISH Packet sent by a Client to a
// Server, the Server MUST store the Application Message and its QoS, so that it can be
// delivered to future subscribers whose subscriptions match its topic name.
func (this *PublishMessage) Retain() bool {
return (this.Flags() & 0x1) == 1
}
// SetRetain sets the value of the RETAIN flag.
func (this *PublishMessage) SetRetain(v bool) {
if v {
this.mtypeflags[0] |= 0x1 // 00000001
} else {
this.mtypeflags[0] &= 254 // 11111110
}
}
// QoS returns the field that indicates the level of assurance for delivery of an
// Application Message. The values are QosAtMostOnce, QosAtLeastOnce and QosExactlyOnce.
func (this *PublishMessage) QoS() byte {
return (this.Flags() >> 1) & 0x3
}
// SetQoS sets the field that indicates the level of assurance for delivery of an
// Application Message. The values are QosAtMostOnce, QosAtLeastOnce and QosExactlyOnce.
// An error is returned if the value is not one of these.
func (this *PublishMessage) SetQoS(v byte) error {
if v != 0x0 && v != 0x1 && v != 0x2 {
return fmt.Errorf("publish/SetQoS: Invalid QoS %d.", v)
}
this.mtypeflags[0] = (this.mtypeflags[0] & 249) | (v << 1) // 249 = 11111001
return nil
}
// Topic returns the the topic name that identifies the information channel to which
// payload data is published.
func (this *PublishMessage) Topic() []byte {
return this.topic
}
// SetTopic sets the the topic name that identifies the information channel to which
// payload data is published. An error is returned if ValidTopic() is falbase.
func (this *PublishMessage) SetTopic(v []byte) error {
if !ValidTopic(v) {
return fmt.Errorf("publish/SetTopic: Invalid topic name (%s). Must not be empty or contain wildcard characters", string(v))
}
this.topic = v
this.dirty = true
return nil
}
// Payload returns the application message that's part of the PUBLISH message.
func (this *PublishMessage) Payload() []byte {
return this.payload
}
// SetPayload sets the application message that's part of the PUBLISH message.
func (this *PublishMessage) SetPayload(v []byte) {
this.payload = v
this.dirty = true
}
func (this *PublishMessage) Len() int {
if !this.dirty {
return len(this.dbuf)
}
ml := this.msglen()
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0
}
return this.header.msglen() + ml
}
func (this *PublishMessage) Decode(src []byte) (int, error) {
total := 0
hn, err := this.header.decode(src[total:])
total += hn
if err != nil {
return total, err
}
n := 0
this.topic, n, err = readLPBytes(src[total:])
total += n
if err != nil {
return total, err
}
if !ValidTopic(this.topic) {
return total, fmt.Errorf("publish/Decode: Invalid topic name (%s). Must not be empty or contain wildcard characters", string(this.topic))
}
// The packet identifier field is only present in the PUBLISH packets where the
// QoS level is 1 or 2
if this.QoS() != 0 {
//this.packetId = binary.BigEndian.Uint16(src[total:])
this.packetId = src[total : total+2]
total += 2
}
l := int(this.remlen) - (total - hn)
this.payload = src[total : total+l]
total += len(this.payload)
this.dirty = false
return total, nil
}
func (this *PublishMessage) Encode(dst []byte) (int, error) {
if !this.dirty {
if len(dst) < len(this.dbuf) {
return 0, fmt.Errorf("publish/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
}
return copy(dst, this.dbuf), nil
}
if len(this.topic) == 0 {
return 0, fmt.Errorf("publish/Encode: Topic name is empty.")
}
if len(this.payload) == 0 {
return 0, fmt.Errorf("publish/Encode: Payload is empty.")
}
ml := this.msglen()
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0, err
}
hl := this.header.msglen()
if len(dst) < hl+ml {
return 0, fmt.Errorf("publish/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
}
total := 0
n, err := this.header.encode(dst[total:])
total += n
if err != nil {
return total, err
}
n, err = writeLPBytes(dst[total:], this.topic)
total += n
if err != nil {
return total, err
}
// The packet identifier field is only present in the PUBLISH packets where the QoS level is 1 or 2
if this.QoS() != 0 {
if this.PacketId() == 0 {
this.SetPacketId(uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff))
//this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff)
}
n = copy(dst[total:], this.packetId)
//binary.BigEndian.PutUint16(dst[total:], this.packetId)
total += n
}
copy(dst[total:], this.payload)
total += len(this.payload)
return total, nil
}
func (this *PublishMessage) msglen() int {
total := 2 + len(this.topic) + len(this.payload)
if this.QoS() != 0 {
total += 2
}
return total
}

View File

@@ -1,281 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestPublishMessageHeaderFields(t *testing.T) {
msg := NewPublishMessage()
msg.mtypeflags[0] |= 11
require.True(t, msg.Dup(), "Incorrect DUP flag.")
require.True(t, msg.Retain(), "Incorrect RETAIN flag.")
require.Equal(t, 1, int(msg.QoS()), "Incorrect QoS.")
msg.SetDup(false)
require.False(t, msg.Dup(), "Incorrect DUP flag.")
msg.SetRetain(false)
require.False(t, msg.Retain(), "Incorrect RETAIN flag.")
err := msg.SetQoS(2)
require.NoError(t, err, "Error setting QoS.")
require.Equal(t, 2, int(msg.QoS()), "Incorrect QoS.")
err = msg.SetQoS(3)
require.Error(t, err)
err = msg.SetQoS(0)
require.NoError(t, err, "Error setting QoS.")
require.Equal(t, 0, int(msg.QoS()), "Incorrect QoS.")
msg.SetDup(true)
require.True(t, msg.Dup(), "Incorrect DUP flag.")
msg.SetRetain(true)
require.True(t, msg.Retain(), "Incorrect RETAIN flag.")
}
func TestPublishMessageFields(t *testing.T) {
msg := NewPublishMessage()
msg.SetTopic([]byte("coolstuff"))
require.Equal(t, "coolstuff", string(msg.Topic()), "Error setting message topic.")
err := msg.SetTopic([]byte("coolstuff/#"))
require.Error(t, err)
msg.SetPacketId(100)
require.Equal(t, 100, int(msg.PacketId()), "Error setting acket ID.")
msg.SetPayload([]byte("this is a payload to be sent"))
require.Equal(t, []byte("this is a payload to be sent"), msg.Payload(), "Error setting payload.")
}
func TestPublishMessageDecode1(t *testing.T) {
msgBytes := []byte{
byte(PUBLISH<<4) | 2,
23,
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // packet ID MSB (0)
7, // packet ID LSB (7)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
}
msg := NewPublishMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
require.Equal(t, "surgemq", string(msg.Topic()), "Error deocding topic name.")
require.Equal(t, []byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'}, msg.Payload(), "Error deocding payload.")
}
// test insufficient bytes
func TestPublishMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(PUBLISH<<4) | 2,
26,
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // packet ID MSB (0)
7, // packet ID LSB (7)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
}
msg := NewPublishMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
// test qos = 0 and no client id
func TestPublishMessageDecode3(t *testing.T) {
msgBytes := []byte{
byte(PUBLISH << 4),
21,
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
}
msg := NewPublishMessage()
_, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
}
func TestPublishMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(PUBLISH<<4) | 2,
23,
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // packet ID MSB (0)
7, // packet ID LSB (7)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
}
msg := NewPublishMessage()
msg.SetTopic([]byte("surgemq"))
msg.SetQoS(1)
msg.SetPacketId(7)
msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'})
dst := make([]byte, 100)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test empty topic name
func TestPublishMessageEncode2(t *testing.T) {
msg := NewPublishMessage()
msg.SetTopic([]byte(""))
msg.SetPacketId(7)
msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'})
dst := make([]byte, 100)
_, err := msg.Encode(dst)
require.Error(t, err)
}
// test encoding qos = 0 and no packet id
func TestPublishMessageEncode3(t *testing.T) {
msgBytes := []byte{
byte(PUBLISH << 4),
21,
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
}
msg := NewPublishMessage()
msg.SetTopic([]byte("surgemq"))
msg.SetQoS(0)
msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'})
dst := make([]byte, 100)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test large message
func TestPublishMessageEncode4(t *testing.T) {
msgBytes := []byte{
byte(PUBLISH << 4),
137,
8,
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
}
payload := make([]byte, 1024)
msgBytes = append(msgBytes, payload...)
msg := NewPublishMessage()
msg.SetTopic([]byte("surgemq"))
msg.SetQoS(0)
msg.SetPayload(payload)
require.Equal(t, len(msgBytes), msg.Len())
dst := make([]byte, 1100)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test from github issue #2, @mrdg
func TestPublishDecodeEncodeEquiv2(t *testing.T) {
msgBytes := []byte{50, 18, 0, 9, 103, 114, 101, 101, 116, 105, 110, 103, 115, 0, 1, 72, 101, 108, 108, 111}
msg := NewPublishMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestPublishDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(PUBLISH<<4) | 2,
23,
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // packet ID MSB (0)
7, // packet ID LSB (7)
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
}
msg := NewPublishMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,31 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
type PubrecMessage struct {
PubackMessage
}
// A PUBREC Packet is the response to a PUBLISH Packet with QoS 2. It is the second
// packet of the QoS 2 protocol exchange.
var _ Message = (*PubrecMessage)(nil)
// NewPubrecMessage creates a new PUBREC message.
func NewPubrecMessage() *PubrecMessage {
msg := &PubrecMessage{}
msg.SetType(PUBREC)
return msg
}

View File

@@ -1,108 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestPubrecMessageFields(t *testing.T) {
msg := NewPubrecMessage()
msg.SetPacketId(100)
require.Equal(t, 100, int(msg.PacketId()))
}
func TestPubrecMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(PUBREC << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubrecMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, PUBREC, msg.Type(), "Error decoding message.")
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
}
// test insufficient bytes
func TestPubrecMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(PUBREC << 4),
2,
7, // packet ID LSB (7)
}
msg := NewPubrecMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestPubrecMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(PUBREC << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubrecMessage()
msg.SetPacketId(7)
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestPubrecDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(PUBREC << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubrecMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,31 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
// A PUBREL Packet is the response to a PUBREC Packet. It is the third packet of the
// QoS 2 protocol exchange.
type PubrelMessage struct {
PubackMessage
}
var _ Message = (*PubrelMessage)(nil)
// NewPubrelMessage creates a new PUBREL message.
func NewPubrelMessage() *PubrelMessage {
msg := &PubrelMessage{}
msg.SetType(PUBREL)
return msg
}

View File

@@ -1,108 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestPubrelMessageFields(t *testing.T) {
msg := NewPubrelMessage()
msg.SetPacketId(100)
require.Equal(t, 100, int(msg.PacketId()))
}
func TestPubrelMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(PUBREL<<4) | 2,
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubrelMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, PUBREL, msg.Type(), "Error decoding message.")
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
}
// test insufficient bytes
func TestPubrelMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(PUBREL<<4) | 2,
2,
7, // packet ID LSB (7)
}
msg := NewPubrelMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestPubrelMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(PUBREL<<4) | 2,
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubrelMessage()
msg.SetPacketId(7)
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestPubrelDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(PUBREL<<4) | 2,
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewPubrelMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,160 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import "fmt"
// A SUBACK Packet is sent by the Server to the Client to confirm receipt and processing
// of a SUBSCRIBE Packet.
//
// A SUBACK Packet contains a list of return codes, that specify the maximum QoS level
// that was granted in each Subscription that was requested by the SUBSCRIBE.
type SubackMessage struct {
header
returnCodes []byte
}
var _ Message = (*SubackMessage)(nil)
// NewSubackMessage creates a new SUBACK message.
func NewSubackMessage() *SubackMessage {
msg := &SubackMessage{}
msg.SetType(SUBACK)
return msg
}
// String returns a string representation of the message.
func (this SubackMessage) String() string {
return fmt.Sprintf("%s, Packet ID=%d, Return Codes=%v", this.header, this.PacketId(), this.returnCodes)
}
// ReturnCodes returns the list of QoS returns from the subscriptions sent in the SUBSCRIBE message.
func (this *SubackMessage) ReturnCodes() []byte {
return this.returnCodes
}
// AddReturnCodes sets the list of QoS returns from the subscriptions sent in the SUBSCRIBE message.
// An error is returned if any of the QoS values are not valid.
func (this *SubackMessage) AddReturnCodes(ret []byte) error {
for _, c := range ret {
if c != QosAtMostOnce && c != QosAtLeastOnce && c != QosExactlyOnce && c != QosFailure {
return fmt.Errorf("suback/AddReturnCode: Invalid return code %d. Must be 0, 1, 2, 0x80.", c)
}
this.returnCodes = append(this.returnCodes, c)
}
this.dirty = true
return nil
}
// AddReturnCode adds a single QoS return value.
func (this *SubackMessage) AddReturnCode(ret byte) error {
return this.AddReturnCodes([]byte{ret})
}
func (this *SubackMessage) Len() int {
if !this.dirty {
return len(this.dbuf)
}
ml := this.msglen()
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0
}
return this.header.msglen() + ml
}
func (this *SubackMessage) Decode(src []byte) (int, error) {
total := 0
hn, err := this.header.decode(src[total:])
total += hn
if err != nil {
return total, err
}
//this.packetId = binary.BigEndian.Uint16(src[total:])
this.packetId = src[total : total+2]
total += 2
l := int(this.remlen) - (total - hn)
this.returnCodes = src[total : total+l]
total += len(this.returnCodes)
for i, code := range this.returnCodes {
if code != 0x00 && code != 0x01 && code != 0x02 && code != 0x80 {
return total, fmt.Errorf("suback/Decode: Invalid return code %d for topic %d", code, i)
}
}
this.dirty = false
return total, nil
}
func (this *SubackMessage) Encode(dst []byte) (int, error) {
if !this.dirty {
if len(dst) < len(this.dbuf) {
return 0, fmt.Errorf("suback/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
}
return copy(dst, this.dbuf), nil
}
for i, code := range this.returnCodes {
if code != 0x00 && code != 0x01 && code != 0x02 && code != 0x80 {
return 0, fmt.Errorf("suback/Encode: Invalid return code %d for topic %d", code, i)
}
}
hl := this.header.msglen()
ml := this.msglen()
if len(dst) < hl+ml {
return 0, fmt.Errorf("suback/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
}
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0, err
}
total := 0
n, err := this.header.encode(dst[total:])
total += n
if err != nil {
return total, err
}
if copy(dst[total:total+2], this.packetId) != 2 {
dst[total], dst[total+1] = 0, 0
}
total += 2
copy(dst[total:], this.returnCodes)
total += len(this.returnCodes)
return total, nil
}
func (this *SubackMessage) msglen() int {
return 2 + len(this.returnCodes)
}

View File

@@ -1,134 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSubackMessageFields(t *testing.T) {
msg := NewSubackMessage()
msg.SetPacketId(100)
require.Equal(t, 100, int(msg.PacketId()), "Error setting packet ID.")
msg.AddReturnCode(1)
require.Equal(t, 1, len(msg.ReturnCodes()), "Error adding return code.")
err := msg.AddReturnCode(0x90)
require.Error(t, err)
}
func TestSubackMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(SUBACK << 4),
6,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // return code 1
1, // return code 2
2, // return code 3
0x80, // return code 4
}
msg := NewSubackMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, SUBACK, msg.Type(), "Error decoding message.")
require.Equal(t, 4, len(msg.ReturnCodes()), "Error adding return code.")
}
// test with wrong return code
func TestSubackMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(SUBACK << 4),
6,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // return code 1
1, // return code 2
2, // return code 3
0x81, // return code 4
}
msg := NewSubackMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestSubackMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(SUBACK << 4),
6,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // return code 1
1, // return code 2
2, // return code 3
0x80, // return code 4
}
msg := NewSubackMessage()
msg.SetPacketId(7)
msg.AddReturnCode(0)
msg.AddReturnCode(1)
msg.AddReturnCode(2)
msg.AddReturnCode(0x80)
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestSubackDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(SUBACK << 4),
6,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // return code 1
1, // return code 2
2, // return code 3
0x80, // return code 4
}
msg := NewSubackMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,253 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"bytes"
"fmt"
"sync/atomic"
)
// The SUBSCRIBE Packet is sent from the Client to the Server to create one or more
// Subscriptions. Each Subscription registers a Clients interest in one or more
// Topics. The Server sends PUBLISH Packets to the Client in order to forward
// Application Messages that were published to Topics that match these Subscriptions.
// The SUBSCRIBE Packet also specifies (for each Subscription) the maximum QoS with
// which the Server can send Application Messages to the Client.
type SubscribeMessage struct {
header
topics [][]byte
qos []byte
}
var _ Message = (*SubscribeMessage)(nil)
// NewSubscribeMessage creates a new SUBSCRIBE message.
func NewSubscribeMessage() *SubscribeMessage {
msg := &SubscribeMessage{}
msg.SetType(SUBSCRIBE)
return msg
}
func (this SubscribeMessage) String() string {
msgstr := fmt.Sprintf("%s, Packet ID=%d", this.header, this.PacketId())
for i, t := range this.topics {
msgstr = fmt.Sprintf("%s, Topic[%d]=%q/%d", msgstr, i, string(t), this.qos[i])
}
return msgstr
}
// Topics returns a list of topics sent by the Client.
func (this *SubscribeMessage) Topics() [][]byte {
return this.topics
}
// AddTopic adds a single topic to the message, along with the corresponding QoS.
// An error is returned if QoS is invalid.
func (this *SubscribeMessage) AddTopic(topic []byte, qos byte) error {
if !ValidQos(qos) {
return fmt.Errorf("Invalid QoS %d", qos)
}
var i int
var t []byte
var found bool
for i, t = range this.topics {
if bytes.Equal(t, topic) {
found = true
break
}
}
if found {
this.qos[i] = qos
return nil
}
this.topics = append(this.topics, topic)
this.qos = append(this.qos, qos)
this.dirty = true
return nil
}
// RemoveTopic removes a single topic from the list of existing ones in the message.
// If topic does not exist it just does nothing.
func (this *SubscribeMessage) RemoveTopic(topic []byte) {
var i int
var t []byte
var found bool
for i, t = range this.topics {
if bytes.Equal(t, topic) {
found = true
break
}
}
if found {
this.topics = append(this.topics[:i], this.topics[i+1:]...)
this.qos = append(this.qos[:i], this.qos[i+1:]...)
}
this.dirty = true
}
// TopicExists checks to see if a topic exists in the list.
func (this *SubscribeMessage) TopicExists(topic []byte) bool {
for _, t := range this.topics {
if bytes.Equal(t, topic) {
return true
}
}
return false
}
// TopicQos returns the QoS level of a topic. If topic does not exist, QosFailure
// is returned.
func (this *SubscribeMessage) TopicQos(topic []byte) byte {
for i, t := range this.topics {
if bytes.Equal(t, topic) {
return this.qos[i]
}
}
return QosFailure
}
// Qos returns the list of QoS current in the message.
func (this *SubscribeMessage) Qos() []byte {
return this.qos
}
func (this *SubscribeMessage) Len() int {
if !this.dirty {
return len(this.dbuf)
}
ml := this.msglen()
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0
}
return this.header.msglen() + ml
}
func (this *SubscribeMessage) Decode(src []byte) (int, error) {
total := 0
hn, err := this.header.decode(src[total:])
total += hn
if err != nil {
return total, err
}
//this.packetId = binary.BigEndian.Uint16(src[total:])
this.packetId = src[total : total+2]
total += 2
remlen := int(this.remlen) - (total - hn)
for remlen > 0 {
t, n, err := readLPBytes(src[total:])
total += n
if err != nil {
return total, err
}
this.topics = append(this.topics, t)
this.qos = append(this.qos, src[total])
total++
remlen = remlen - n - 1
}
if len(this.topics) == 0 {
return 0, fmt.Errorf("subscribe/Decode: Empty topic list")
}
this.dirty = false
return total, nil
}
func (this *SubscribeMessage) Encode(dst []byte) (int, error) {
if !this.dirty {
if len(dst) < len(this.dbuf) {
return 0, fmt.Errorf("subscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
}
return copy(dst, this.dbuf), nil
}
hl := this.header.msglen()
ml := this.msglen()
if len(dst) < hl+ml {
return 0, fmt.Errorf("subscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
}
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0, err
}
total := 0
n, err := this.header.encode(dst[total:])
total += n
if err != nil {
return total, err
}
if this.PacketId() == 0 {
this.SetPacketId(uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff))
//this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff)
}
n = copy(dst[total:], this.packetId)
//binary.BigEndian.PutUint16(dst[total:], this.packetId)
total += n
for i, t := range this.topics {
n, err := writeLPBytes(dst[total:], t)
total += n
if err != nil {
return total, err
}
dst[total] = this.qos[i]
total++
}
return total, nil
}
func (this *SubscribeMessage) msglen() int {
// packet ID
total := 2
for _, t := range this.topics {
total += 2 + len(t) + 1
}
return total
}

View File

@@ -1,161 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestSubscribeMessageFields(t *testing.T) {
msg := NewSubscribeMessage()
msg.SetPacketId(100)
require.Equal(t, 100, int(msg.PacketId()), "Error setting packet ID.")
msg.AddTopic([]byte("/a/b/#/c"), 1)
require.Equal(t, 1, len(msg.Topics()), "Error adding topic.")
require.False(t, msg.TopicExists([]byte("a/b")), "Topic should not exist.")
msg.RemoveTopic([]byte("/a/b/#/c"))
require.False(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.")
}
func TestSubscribeMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(SUBSCRIBE<<4) | 2,
36,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // QoS
0, // topic name MSB (0)
8, // topic name LSB (8)
'/', 'a', '/', 'b', '/', '#', '/', 'c',
1, // QoS
0, // topic name MSB (0)
10, // topic name LSB (10)
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
2, // QoS
}
msg := NewSubscribeMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, SUBSCRIBE, msg.Type(), "Error decoding message.")
require.Equal(t, 3, len(msg.Topics()), "Error decoding topics.")
require.True(t, msg.TopicExists([]byte("surgemq")), "Topic 'surgemq' should exist.")
require.Equal(t, 0, int(msg.TopicQos([]byte("surgemq"))), "Incorrect topic qos.")
require.True(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic '/a/b/#/c' should exist.")
require.Equal(t, 1, int(msg.TopicQos([]byte("/a/b/#/c"))), "Incorrect topic qos.")
require.True(t, msg.TopicExists([]byte("/a/b/#/cdd")), "Topic '/a/b/#/c' should exist.")
require.Equal(t, 2, int(msg.TopicQos([]byte("/a/b/#/cdd"))), "Incorrect topic qos.")
}
// test empty topic list
func TestSubscribeMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(SUBSCRIBE<<4) | 2,
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewSubscribeMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestSubscribeMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(SUBSCRIBE<<4) | 2,
36,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // QoS
0, // topic name MSB (0)
8, // topic name LSB (8)
'/', 'a', '/', 'b', '/', '#', '/', 'c',
1, // QoS
0, // topic name MSB (0)
10, // topic name LSB (10)
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
2, // QoS
}
msg := NewSubscribeMessage()
msg.SetPacketId(7)
msg.AddTopic([]byte("surgemq"), 0)
msg.AddTopic([]byte("/a/b/#/c"), 1)
msg.AddTopic([]byte("/a/b/#/cdd"), 2)
dst := make([]byte, 100)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestSubscribeDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(SUBSCRIBE<<4) | 2,
36,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // QoS
0, // topic name MSB (0)
8, // topic name LSB (8)
'/', 'a', '/', 'b', '/', '#', '/', 'c',
1, // QoS
0, // topic name MSB (0)
10, // topic name LSB (10)
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
2, // QoS
}
msg := NewSubscribeMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,31 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
// The UNSUBACK Packet is sent by the Server to the Client to confirm receipt of an
// UNSUBSCRIBE Packet.
type UnsubackMessage struct {
PubackMessage
}
var _ Message = (*UnsubackMessage)(nil)
// NewUnsubackMessage creates a new UNSUBACK message.
func NewUnsubackMessage() *UnsubackMessage {
msg := &UnsubackMessage{}
msg.SetType(UNSUBACK)
return msg
}

View File

@@ -1,108 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestUnsubackMessageFields(t *testing.T) {
msg := NewUnsubackMessage()
msg.SetPacketId(100)
require.Equal(t, 100, int(msg.PacketId()))
}
func TestUnsubackMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(UNSUBACK << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewUnsubackMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, UNSUBACK, msg.Type(), "Error decoding message.")
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
}
// test insufficient bytes
func TestUnsubackMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(UNSUBACK << 4),
2,
7, // packet ID LSB (7)
}
msg := NewUnsubackMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestUnsubackMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(UNSUBACK << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewUnsubackMessage()
msg.SetPacketId(7)
dst := make([]byte, 10)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestUnsubackDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(UNSUBACK << 4),
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewUnsubackMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

View File

@@ -1,210 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"bytes"
"fmt"
"sync/atomic"
)
// An UNSUBSCRIBE Packet is sent by the Client to the Server, to unsubscribe from topics.
type UnsubscribeMessage struct {
header
topics [][]byte
}
var _ Message = (*UnsubscribeMessage)(nil)
// NewUnsubscribeMessage creates a new UNSUBSCRIBE message.
func NewUnsubscribeMessage() *UnsubscribeMessage {
msg := &UnsubscribeMessage{}
msg.SetType(UNSUBSCRIBE)
return msg
}
func (this UnsubscribeMessage) String() string {
msgstr := fmt.Sprintf("%s", this.header)
for i, t := range this.topics {
msgstr = fmt.Sprintf("%s, Topic%d=%s", msgstr, i, string(t))
}
return msgstr
}
// Topics returns a list of topics sent by the Client.
func (this *UnsubscribeMessage) Topics() [][]byte {
return this.topics
}
// AddTopic adds a single topic to the message.
func (this *UnsubscribeMessage) AddTopic(topic []byte) {
if this.TopicExists(topic) {
return
}
this.topics = append(this.topics, topic)
this.dirty = true
}
// RemoveTopic removes a single topic from the list of existing ones in the message.
// If topic does not exist it just does nothing.
func (this *UnsubscribeMessage) RemoveTopic(topic []byte) {
var i int
var t []byte
var found bool
for i, t = range this.topics {
if bytes.Equal(t, topic) {
found = true
break
}
}
if found {
this.topics = append(this.topics[:i], this.topics[i+1:]...)
}
this.dirty = true
}
// TopicExists checks to see if a topic exists in the list.
func (this *UnsubscribeMessage) TopicExists(topic []byte) bool {
for _, t := range this.topics {
if bytes.Equal(t, topic) {
return true
}
}
return false
}
func (this *UnsubscribeMessage) Len() int {
if !this.dirty {
return len(this.dbuf)
}
ml := this.msglen()
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0
}
return this.header.msglen() + ml
}
// Decode reads from the io.Reader parameter until a full message is decoded, or
// when io.Reader returns EOF or error. The first return value is the number of
// bytes read from io.Reader. The second is error if Decode encounters any problems.
func (this *UnsubscribeMessage) Decode(src []byte) (int, error) {
total := 0
hn, err := this.header.decode(src[total:])
total += hn
if err != nil {
return total, err
}
//this.packetId = binary.BigEndian.Uint16(src[total:])
this.packetId = src[total : total+2]
total += 2
remlen := int(this.remlen) - (total - hn)
for remlen > 0 {
t, n, err := readLPBytes(src[total:])
total += n
if err != nil {
return total, err
}
this.topics = append(this.topics, t)
remlen = remlen - n - 1
}
if len(this.topics) == 0 {
return 0, fmt.Errorf("unsubscribe/Decode: Empty topic list")
}
this.dirty = false
return total, nil
}
// Encode returns an io.Reader in which the encoded bytes can be read. The second
// return value is the number of bytes encoded, so the caller knows how many bytes
// there will be. If Encode returns an error, then the first two return values
// should be considered invalid.
// Any changes to the message after Encode() is called will invalidate the io.Reader.
func (this *UnsubscribeMessage) Encode(dst []byte) (int, error) {
if !this.dirty {
if len(dst) < len(this.dbuf) {
return 0, fmt.Errorf("unsubscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
}
return copy(dst, this.dbuf), nil
}
hl := this.header.msglen()
ml := this.msglen()
if len(dst) < hl+ml {
return 0, fmt.Errorf("unsubscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
}
if err := this.SetRemainingLength(int32(ml)); err != nil {
return 0, err
}
total := 0
n, err := this.header.encode(dst[total:])
total += n
if err != nil {
return total, err
}
if this.PacketId() == 0 {
this.SetPacketId(uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff))
//this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff)
}
n = copy(dst[total:], this.packetId)
//binary.BigEndian.PutUint16(dst[total:], this.packetId)
total += n
for _, t := range this.topics {
n, err := writeLPBytes(dst[total:], t)
total += n
if err != nil {
return total, err
}
}
return total, nil
}
func (this *UnsubscribeMessage) msglen() int {
// packet ID
total := 2
for _, t := range this.topics {
total += 2 + len(t)
}
return total
}

View File

@@ -1,155 +0,0 @@
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package message
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestUnsubscribeMessageFields(t *testing.T) {
msg := NewUnsubscribeMessage()
msg.SetPacketId(100)
require.Equal(t, 100, int(msg.PacketId()), "Error setting packet ID.")
msg.AddTopic([]byte("/a/b/#/c"))
require.Equal(t, 1, len(msg.Topics()), "Error adding topic.")
msg.AddTopic([]byte("/a/b/#/c"))
require.Equal(t, 1, len(msg.Topics()), "Error adding duplicate topic.")
msg.RemoveTopic([]byte("/a/b/#/c"))
require.False(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.")
require.False(t, msg.TopicExists([]byte("a/b")), "Topic should not exist.")
msg.RemoveTopic([]byte("/a/b/#/c"))
require.False(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.")
}
func TestUnsubscribeMessageDecode(t *testing.T) {
msgBytes := []byte{
byte(UNSUBSCRIBE<<4) | 2,
33,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // topic name MSB (0)
8, // topic name LSB (8)
'/', 'a', '/', 'b', '/', '#', '/', 'c',
0, // topic name MSB (0)
10, // topic name LSB (10)
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
}
msg := NewUnsubscribeMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, UNSUBSCRIBE, msg.Type(), "Error decoding message.")
require.Equal(t, 3, len(msg.Topics()), "Error decoding topics.")
require.True(t, msg.TopicExists([]byte("surgemq")), "Topic 'surgemq' should exist.")
require.True(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic '/a/b/#/c' should exist.")
require.True(t, msg.TopicExists([]byte("/a/b/#/cdd")), "Topic '/a/b/#/c' should exist.")
}
// test empty topic list
func TestUnsubscribeMessageDecode2(t *testing.T) {
msgBytes := []byte{
byte(UNSUBSCRIBE<<4) | 2,
2,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
}
msg := NewUnsubscribeMessage()
_, err := msg.Decode(msgBytes)
require.Error(t, err)
}
func TestUnsubscribeMessageEncode(t *testing.T) {
msgBytes := []byte{
byte(UNSUBSCRIBE<<4) | 2,
33,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // topic name MSB (0)
8, // topic name LSB (8)
'/', 'a', '/', 'b', '/', '#', '/', 'c',
0, // topic name MSB (0)
10, // topic name LSB (10)
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
}
msg := NewUnsubscribeMessage()
msg.SetPacketId(7)
msg.AddTopic([]byte("surgemq"))
msg.AddTopic([]byte("/a/b/#/c"))
msg.AddTopic([]byte("/a/b/#/cdd"))
dst := make([]byte, 100)
n, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
}
// test to ensure encoding and decoding are the same
// decode, encode, and decode again
func TestUnsubscribeDecodeEncodeEquiv(t *testing.T) {
msgBytes := []byte{
byte(UNSUBSCRIBE<<4) | 2,
33,
0, // packet ID MSB (0)
7, // packet ID LSB (7)
0, // topic name MSB (0)
7, // topic name LSB (7)
's', 'u', 'r', 'g', 'e', 'm', 'q',
0, // topic name MSB (0)
8, // topic name LSB (8)
'/', 'a', '/', 'b', '/', '#', '/', 'c',
0, // topic name MSB (0)
10, // topic name LSB (10)
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
}
msg := NewUnsubscribeMessage()
n, err := msg.Decode(msgBytes)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n, "Error decoding message.")
dst := make([]byte, 100)
n2, err := msg.Encode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
n3, err := msg.Decode(dst)
require.NoError(t, err, "Error decoding message.")
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
}

51
logger/logger.go Normal file
View File

@@ -0,0 +1,51 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package logger
import (
"go.uber.org/zap"
)
var (
// env can be setup at build time with Go Linker. Value could be prod or whatever else for dev env
env string
instance *zap.Logger
logCfg zap.Config
)
// NewDevLogger return a logger for dev builds
func NewDevLogger() (*zap.Logger, error) {
logCfg := zap.NewDevelopmentConfig()
return logCfg.Build()
}
// NewProdLogger return a logger for production builds
func NewProdLogger() (*zap.Logger, error) {
logCfg := zap.NewProductionConfig()
logCfg.DisableStacktrace = true
logCfg.Level = zap.NewAtomicLevelAt(zap.InfoLevel)
return logCfg.Build()
}
func init() {
var err error
var log *zap.Logger
if env == "prod" {
log, err = NewProdLogger()
} else {
log, err = NewDevLogger()
}
if err != nil {
panic("Unable to create a logger.")
}
defer log.Sync()
log.Debug("Logger initialization succeeded")
instance = log.Named("hmq")
}
// Get return a *zap.Logger instance
func Get() *zap.Logger {
return instance
}

33
logger/logger_test.go Normal file
View File

@@ -0,0 +1,33 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
*/
package logger
import (
"testing"
"github.com/stretchr/testify/assert"
"go.uber.org/zap"
)
func TestGet(t *testing.T) {
var l *zap.Logger
logger := Get()
assert.NotNil(t, logger)
assert.IsType(t, l, logger)
}
func TestNewDevLogger(t *testing.T) {
logger, err := NewDevLogger()
assert.Nil(t, err)
assert.True(t, logger.Core().Enabled(zap.DebugLevel))
}
func TestNewProdLogger(t *testing.T) {
logger, err := NewProdLogger()
assert.Nil(t, err)
assert.False(t, logger.Core().Enabled(zap.DebugLevel))
}

31
main.go
View File

@@ -1,29 +1,42 @@
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
Permission to use, copy, modify, and/or distribute this software for any
purpose with or without fee is hereby granted, provided that the above
copyright notice and this permission notice appear in all copies.
*/
package main
import (
"hmq/broker"
"os"
"os/signal"
"runtime"
log "github.com/cihub/seelog"
"github.com/fhmq/hmq/broker"
"github.com/fhmq/hmq/logger"
"go.uber.org/zap"
)
var (
log = logger.Get().Named("Main")
)
func main() {
config, er := broker.LoadConfig()
if er != nil {
log.Error("Load Config file error: ", er)
runtime.GOMAXPROCS(runtime.NumCPU())
config, err := broker.ConfigureConfig()
if err != nil {
log.Error("configure broker config error: ", zap.Error(err))
return
}
broker, err := broker.NewBroker(config)
b, err := broker.NewBroker(config)
if err != nil {
log.Error("New Broker error: ", er)
log.Error("New Broker error: ", zap.Error(err))
return
}
broker.Start()
b.Start()
s := waitForSignal()
log.Infof("signal got: %v ,broker closed.", s)
log.Info("signal received, broker closed.", zap.Any("signal", s))
}
func waitForSignal() os.Signal {