mirror of
https://github.com/fhmq/hmq.git
synced 2026-05-06 07:35:32 +00:00
Compare commits
85 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| daf4a0e0f5 | |||
| c350d16ca1 | |||
| edc46c1ee6 | |||
| 6193be74fa | |||
| 90beada459 | |||
| 6c7fe6a0f7 | |||
| 2b56664d85 | |||
| 7547ad3bdc | |||
| 84e7fe2490 | |||
| 684584b208 | |||
| 56fb4a2d54 | |||
| 5ed4728575 | |||
| c0fea6a5ba | |||
| 47500910e1 | |||
| 0ff20b6ee2 | |||
| 7155667f6c | |||
| 83db82cdcc | |||
| b3653bcfb1 | |||
| 221d00480e | |||
| 91733bf91e | |||
| ef252550dc | |||
| 1058256235 | |||
| 5a569f14a3 | |||
| 93b21777ff | |||
| dcf2934e1b | |||
| d9e6e216b0 | |||
| ca3951769a | |||
| 0439e7ce90 | |||
| dc0f2185ab | |||
| 7462afcfb5 | |||
| 114e6f901e | |||
| 0cb51bd37a | |||
| 819b4725f2 | |||
| 85bdeccbfc | |||
| 1339a04b28 | |||
| 957329d85c | |||
| 7db7edaa17 | |||
| 1d6f6a4a71 | |||
| 123bb7210f | |||
| 9ad6590e83 | |||
| 516db49db5 | |||
| a260057bfe | |||
| bdd802ebfb | |||
| 5786e69b01 | |||
| 6a89b627d4 | |||
| 208a7cf0a8 | |||
| a7fb7f1912 | |||
| eeab0c6b7d | |||
| 4646042b7f | |||
| 49385e52fd | |||
| 3ed8625bb9 | |||
| 6b50060eae | |||
| 96277996f0 | |||
| 5601632a33 | |||
| cc1b3239ad | |||
| 476d22568b | |||
| c85ba76f8f | |||
| f3b2924b07 | |||
| 6144aeb6bf | |||
| 83b6934621 | |||
| 7d6fcb7d65 | |||
| 35390baa92 | |||
| f2efaa9992 | |||
| 258912b33a | |||
| 7f45bd6bc9 | |||
| 7073e9b4ba | |||
| 8c98346546 | |||
| 4300a32f6b | |||
| ae1af54c6e | |||
| 34378164e0 | |||
| 47c44570fc | |||
| 417a12174c | |||
| 43a6bb8c5d | |||
| 18d18738be | |||
| 8af790ffba | |||
| 5e937601ce | |||
| 31548b10e5 | |||
| ca7ebfb6e3 | |||
| b98ae9ec6f | |||
| 8bf6ccaa25 | |||
| 65ac09cf50 | |||
| d37100d059 | |||
| 50a9a6841d | |||
| a45cccaa7a | |||
| c732d395e1 |
@@ -1,3 +1,4 @@
|
|||||||
hmq
|
hmq
|
||||||
log
|
log
|
||||||
log/*
|
log/*
|
||||||
|
*.test
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
FROM alpine
|
FROM alpine
|
||||||
COPY hmq /
|
COPY hmq /
|
||||||
COPY broker.config /
|
COPY ssl /ssl
|
||||||
COPY tls /tls
|
|
||||||
COPY conf /conf
|
COPY conf /conf
|
||||||
|
|
||||||
EXPOSE 1883
|
EXPOSE 1883
|
||||||
EXPOSE 1888
|
EXPOSE 1888
|
||||||
|
EXPOSE 8883
|
||||||
EXPOSE 1993
|
EXPOSE 1993
|
||||||
|
|
||||||
CMD ["/hmq"]
|
CMD ["/hmq"]
|
||||||
Executable → Regular
+1
-1
@@ -1,4 +1,4 @@
|
|||||||
Apache License
|
Apache License
|
||||||
Version 2.0, January 2004
|
Version 2.0, January 2004
|
||||||
http://www.apache.org/licenses/
|
http://www.apache.org/licenses/
|
||||||
|
|
||||||
@@ -3,25 +3,52 @@ Free and High Performance MQTT Broker
|
|||||||
|
|
||||||
## About
|
## About
|
||||||
Golang MQTT Broker, Version 3.1.1, and Compatible
|
Golang MQTT Broker, Version 3.1.1, and Compatible
|
||||||
for [eclipse paho client](https://github.com/eclipse?utf8=%E2%9C%93&q=mqtt&type=&language=)
|
for [eclipse paho client](https://github.com/eclipse?utf8=%E2%9C%93&q=mqtt&type=&language=) and mosquitto-client
|
||||||
|
|
||||||
|
Download: [click here](https://github.com/fhmq/hmq/releases)
|
||||||
|
|
||||||
## RUNNING
|
## RUNNING
|
||||||
```bash
|
```bash
|
||||||
$ git clone https://github.com/fhmq/hmq.git
|
$ go get github.com/fhmq/hmq
|
||||||
$ cd hmq
|
$ cd $GOPATH/github.com/fhmq/hmq
|
||||||
$ go run main.go
|
$ go run main.go
|
||||||
```
|
```
|
||||||
|
|
||||||
### broker.config
|
## Usage of hmq:
|
||||||
|
~~~
|
||||||
|
Usage: hmq [options]
|
||||||
|
|
||||||
|
Broker Options:
|
||||||
|
-w, --worker <number> Worker num to process message, perfer (client num)/10. (default 1024)
|
||||||
|
-p, --port <port> Use port for clients (default: 1883)
|
||||||
|
--host <host> Network host to listen on. (default "0.0.0.0")
|
||||||
|
-ws, --wsport <port> Use port for websocket monitoring
|
||||||
|
-wsp,--wspath <path> Use path for websocket monitoring
|
||||||
|
-c, --config <file> Configuration file
|
||||||
|
|
||||||
|
Logging Options:
|
||||||
|
-d, --debug <bool> Enable debugging output (default false)
|
||||||
|
-D Debug enabled
|
||||||
|
|
||||||
|
Cluster Options:
|
||||||
|
-r, --router <rurl> Router who maintenance cluster info
|
||||||
|
-cp, --clusterport <cluster-port> Cluster listen port for others
|
||||||
|
|
||||||
|
Common Options:
|
||||||
|
-h, --help Show this message
|
||||||
|
~~~
|
||||||
|
|
||||||
|
### hmq.config
|
||||||
~~~
|
~~~
|
||||||
{
|
{
|
||||||
|
"workerNum": 4096,
|
||||||
"port": "1883",
|
"port": "1883",
|
||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"cluster": {
|
"cluster": {
|
||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"port": "1993",
|
"port": "1993"
|
||||||
"routers": ["10.10.0.11:1993","10.10.0.12:1993"]
|
|
||||||
},
|
},
|
||||||
|
"router": "127.0.0.1:9888",
|
||||||
"wsPort": "1888",
|
"wsPort": "1888",
|
||||||
"wsPath": "/ws",
|
"wsPath": "/ws",
|
||||||
"wsTLS": true,
|
"wsTLS": true,
|
||||||
@@ -40,7 +67,7 @@ $ go run main.go
|
|||||||
|
|
||||||
### Features and Future
|
### Features and Future
|
||||||
|
|
||||||
* Supports QOS 0
|
* Supports QOS 0 and 1
|
||||||
|
|
||||||
* Cluster Support
|
* Cluster Support
|
||||||
|
|
||||||
@@ -50,20 +77,21 @@ $ go run main.go
|
|||||||
|
|
||||||
* Supports will messages
|
* Supports will messages
|
||||||
|
|
||||||
* Queue subscribe
|
|
||||||
|
|
||||||
* Websocket Support
|
* Websocket Support
|
||||||
|
|
||||||
* TLS/SSL Support
|
* TLS/SSL Support
|
||||||
|
|
||||||
* Flexible ACL
|
* Flexible ACL
|
||||||
|
|
||||||
### QUEUE SUBSCRIBE
|
### Cluster
|
||||||
~~~
|
```bash
|
||||||
| Prefix | Examples |
|
1, start router for hmq (https://github.com/fhmq/router.git)
|
||||||
| ------------- |---------------------------------|
|
$ go get github.com/fhmq/router
|
||||||
| $queue/ | mosquitto_sub -t ‘$queue/topic’ |
|
$ cd $GOPATH/github.com/fhmq/router
|
||||||
~~~
|
$ go run main.go
|
||||||
|
2, config router in hmq.config ("router": "127.0.0.1:9888")
|
||||||
|
|
||||||
|
```
|
||||||
|
|
||||||
### ACL Configure
|
### ACL Configure
|
||||||
#### The ACL rules define:
|
#### The ACL rules define:
|
||||||
@@ -117,6 +145,14 @@ Client -> | Rule1 | --nomatch--> | Rule2 | --nomatch--> | Rule3 | -->
|
|||||||
allow | deny allow | deny allow | deny
|
allow | deny allow | deny allow | deny
|
||||||
~~~
|
~~~
|
||||||
|
|
||||||
|
### Online/Offline Notification
|
||||||
|
```bash
|
||||||
|
topic:
|
||||||
|
$SYS/broker/connection/clients/<clientID>
|
||||||
|
payload:
|
||||||
|
{"clientID":"client001","online":true/false,"timestamp":"2018-10-25T09:32:32Z"}
|
||||||
|
```
|
||||||
|
|
||||||
## Performance
|
## Performance
|
||||||
|
|
||||||
* High throughput
|
* High throughput
|
||||||
@@ -129,3 +165,8 @@ Client -> | Rule1 | --nomatch--> | Rule2 | --nomatch--> | Rule3 | -->
|
|||||||
## License
|
## License
|
||||||
|
|
||||||
* Apache License Version 2.0
|
* Apache License Version 2.0
|
||||||
|
|
||||||
|
|
||||||
|
## Reference
|
||||||
|
|
||||||
|
* Surgermq.(https://github.com/surgemq/surgemq)
|
||||||
@@ -1 +0,0 @@
|
|||||||
theme: jekyll-theme-slate
|
|
||||||
+11
-10
@@ -1,11 +1,12 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
package broker
|
package broker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"hmq/lib/acl"
|
"github.com/fhmq/hmq/lib/acl"
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/cihub/seelog"
|
|
||||||
"github.com/fsnotify/fsnotify"
|
"github.com/fsnotify/fsnotify"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -14,7 +15,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func (c *client) CheckTopicAuth(typ int, topic string) bool {
|
func (c *client) CheckTopicAuth(typ int, topic string) bool {
|
||||||
if !c.broker.config.Acl {
|
if c.typ != CLIENT || !c.broker.config.Acl {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
if strings.HasPrefix(topic, "$queue/") {
|
if strings.HasPrefix(topic, "$queue/") {
|
||||||
@@ -40,10 +41,10 @@ func (b *Broker) handleFsEvent(event fsnotify.Event) error {
|
|||||||
case b.config.AclConf:
|
case b.config.AclConf:
|
||||||
if event.Op&fsnotify.Write == fsnotify.Write ||
|
if event.Op&fsnotify.Write == fsnotify.Write ||
|
||||||
event.Op&fsnotify.Create == fsnotify.Create {
|
event.Op&fsnotify.Create == fsnotify.Create {
|
||||||
log.Info("text:handling acl config change event:", event)
|
log.Info("text:handling acl config change event:", zap.String("filename", event.Name))
|
||||||
aclconfig, err := acl.AclConfigLoad(event.Name)
|
aclconfig, err := acl.AclConfigLoad(event.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("aclconfig change failed, load acl conf error: ", err)
|
log.Error("aclconfig change failed, load acl conf error: ", zap.Error(err))
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
b.AclConfig = aclconfig
|
b.AclConfig = aclconfig
|
||||||
@@ -56,14 +57,14 @@ func (b *Broker) StartAclWatcher() {
|
|||||||
go func() {
|
go func() {
|
||||||
wch, e := fsnotify.NewWatcher()
|
wch, e := fsnotify.NewWatcher()
|
||||||
if e != nil {
|
if e != nil {
|
||||||
log.Error("start monitor acl config file error,", e)
|
log.Error("start monitor acl config file error,", zap.Error(e))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer wch.Close()
|
defer wch.Close()
|
||||||
|
|
||||||
for _, i := range watchList {
|
for _, i := range watchList {
|
||||||
if err := wch.Add(i); err != nil {
|
if err := wch.Add(i); err != nil {
|
||||||
log.Error("start monitor acl config file error,", err)
|
log.Error("start monitor acl config file error,", zap.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -73,7 +74,7 @@ func (b *Broker) StartAclWatcher() {
|
|||||||
case evt := <-wch.Events:
|
case evt := <-wch.Events:
|
||||||
b.handleFsEvent(evt)
|
b.handleFsEvent(evt)
|
||||||
case err := <-wch.Errors:
|
case err := <-wch.Errors:
|
||||||
log.Error("error:", err.Error())
|
log.Error("error:", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|||||||
+489
-191
@@ -1,48 +1,93 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
package broker
|
package broker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"hmq/lib/acl"
|
"fmt"
|
||||||
"hmq/lib/message"
|
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
|
"runtime/debug"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||||
|
"github.com/fhmq/hmq/lib/acl"
|
||||||
|
"github.com/fhmq/hmq/lib/sessions"
|
||||||
|
"github.com/fhmq/hmq/lib/topics"
|
||||||
|
"github.com/fhmq/hmq/pool"
|
||||||
|
"github.com/shirou/gopsutil/mem"
|
||||||
|
"go.uber.org/zap"
|
||||||
"golang.org/x/net/websocket"
|
"golang.org/x/net/websocket"
|
||||||
|
|
||||||
log "github.com/cihub/seelog"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
MessagePoolNum = 1024
|
||||||
|
MessagePoolMessageNum = 1024
|
||||||
|
)
|
||||||
|
|
||||||
|
type Message struct {
|
||||||
|
client *client
|
||||||
|
packet packets.ControlPacket
|
||||||
|
}
|
||||||
|
|
||||||
type Broker struct {
|
type Broker struct {
|
||||||
id string
|
id string
|
||||||
cid uint64
|
cid uint64
|
||||||
config *Config
|
mu sync.Mutex
|
||||||
tlsConfig *tls.Config
|
config *Config
|
||||||
AclConfig *acl.ACLConfig
|
tlsConfig *tls.Config
|
||||||
clients cMap
|
AclConfig *acl.ACLConfig
|
||||||
routes cMap
|
wpool *pool.WorkerPool
|
||||||
remotes cMap
|
clients sync.Map
|
||||||
sl *Sublist
|
routes sync.Map
|
||||||
rl *RetainList
|
remotes sync.Map
|
||||||
queues map[string]int
|
nodes map[string]interface{}
|
||||||
|
clusterPool chan *Message
|
||||||
|
queues map[string]int
|
||||||
|
topicsMgr *topics.Manager
|
||||||
|
sessionMgr *sessions.Manager
|
||||||
|
// messagePool []chan *Message
|
||||||
|
}
|
||||||
|
|
||||||
|
func newMessagePool() []chan *Message {
|
||||||
|
pool := make([]chan *Message, 0)
|
||||||
|
for i := 0; i < MessagePoolNum; i++ {
|
||||||
|
ch := make(chan *Message, MessagePoolMessageNum)
|
||||||
|
pool = append(pool, ch)
|
||||||
|
}
|
||||||
|
return pool
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBroker(config *Config) (*Broker, error) {
|
func NewBroker(config *Config) (*Broker, error) {
|
||||||
b := &Broker{
|
b := &Broker{
|
||||||
id: GenUniqueId(),
|
id: GenUniqueId(),
|
||||||
config: config,
|
config: config,
|
||||||
sl: NewSublist(),
|
wpool: pool.New(config.Worker),
|
||||||
rl: NewRetainList(),
|
nodes: make(map[string]interface{}),
|
||||||
queues: make(map[string]int),
|
queues: make(map[string]int),
|
||||||
clients: NewClientMap(),
|
clusterPool: make(chan *Message),
|
||||||
routes: NewClientMap(),
|
|
||||||
remotes: NewClientMap(),
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
b.topicsMgr, err = topics.NewManager("mem")
|
||||||
|
if err != nil {
|
||||||
|
log.Error("new topic manager error", zap.Error(err))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
b.sessionMgr, err = sessions.NewManager("mem")
|
||||||
|
if err != nil {
|
||||||
|
log.Error("new session manager error", zap.Error(err))
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
if b.config.TlsPort != "" {
|
if b.config.TlsPort != "" {
|
||||||
tlsconfig, err := NewTLSConfig(b.config.TlsInfo)
|
tlsconfig, err := NewTLSConfig(b.config.TlsInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("new tlsConfig error: ", err)
|
log.Error("new tlsConfig error", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
b.tlsConfig = tlsconfig
|
b.tlsConfig = tlsconfig
|
||||||
@@ -50,7 +95,7 @@ func NewBroker(config *Config) (*Broker, error) {
|
|||||||
if b.config.Acl {
|
if b.config.Acl {
|
||||||
aclconfig, err := acl.AclConfigLoad(b.config.AclConf)
|
aclconfig, err := acl.AclConfigLoad(b.config.AclConf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Load acl conf error: ", err)
|
log.Error("Load acl conf error", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
b.AclConfig = aclconfig
|
b.AclConfig = aclconfig
|
||||||
@@ -59,49 +104,120 @@ func NewBroker(config *Config) (*Broker, error) {
|
|||||||
return b, nil
|
return b, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (b *Broker) SubmitWork(clientId string, msg *Message) {
|
||||||
|
if b.wpool == nil {
|
||||||
|
b.wpool = pool.New(b.config.Worker)
|
||||||
|
}
|
||||||
|
|
||||||
|
if msg.client.typ == CLUSTER {
|
||||||
|
b.clusterPool <- msg
|
||||||
|
} else {
|
||||||
|
b.wpool.Submit(clientId, func() {
|
||||||
|
ProcessMessage(msg)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
func (b *Broker) Start() {
|
func (b *Broker) Start() {
|
||||||
if b == nil {
|
if b == nil {
|
||||||
log.Error("broker is null")
|
log.Error("broker is null")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//listen clinet over tcp
|
||||||
if b.config.Port != "" {
|
if b.config.Port != "" {
|
||||||
go b.StartListening(CLIENT)
|
go b.StartClientListening(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//listen for cluster
|
||||||
if b.config.Cluster.Port != "" {
|
if b.config.Cluster.Port != "" {
|
||||||
go b.StartListening(ROUTER)
|
go b.StartClusterListening()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//listen for websocket
|
||||||
if b.config.WsPort != "" {
|
if b.config.WsPort != "" {
|
||||||
go b.StartWebsocketListening()
|
go b.StartWebsocketListening()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//listen client over tls
|
||||||
if b.config.TlsPort != "" {
|
if b.config.TlsPort != "" {
|
||||||
go b.StartTLSListening()
|
go b.StartClientListening(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
//connect on other node in cluster
|
||||||
|
if b.config.Router != "" {
|
||||||
|
go b.processClusterInfo()
|
||||||
|
b.ConnectToDiscovery()
|
||||||
|
}
|
||||||
|
|
||||||
|
//system monitor
|
||||||
|
go StateMonitor()
|
||||||
|
|
||||||
|
if b.config.Debug {
|
||||||
|
startPProf()
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func startPProf() {
|
||||||
|
go func() {
|
||||||
|
http.ListenAndServe(":10060", nil)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func StateMonitor() {
|
||||||
|
v, _ := mem.VirtualMemory()
|
||||||
|
timeSticker := time.NewTicker(time.Second * 30)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-timeSticker.C:
|
||||||
|
if v.UsedPercent > 75 {
|
||||||
|
debug.FreeOSMemory()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) StartWebsocketListening() {
|
func (b *Broker) StartWebsocketListening() {
|
||||||
path := b.config.WsPath
|
path := b.config.WsPath
|
||||||
hp := ":" + b.config.WsPort
|
hp := ":" + b.config.WsPort
|
||||||
log.Info("Start Webscoker Listening on ", hp, path)
|
log.Info("Start Websocket Listener on:", zap.String("hp", hp), zap.String("path", path))
|
||||||
http.Handle(path, websocket.Handler(b.wsHandler))
|
http.Handle(path, websocket.Handler(b.wsHandler))
|
||||||
err := http.ListenAndServe(hp, nil)
|
var err error
|
||||||
|
if b.config.WsTLS {
|
||||||
|
err = http.ListenAndServeTLS(hp, b.config.TlsInfo.CertFile, b.config.TlsInfo.KeyFile, nil)
|
||||||
|
} else {
|
||||||
|
err = http.ListenAndServe(hp, nil)
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("ListenAndServe: " + err.Error())
|
log.Error("ListenAndServe:" + err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) wsHandler(ws *websocket.Conn) {
|
func (b *Broker) wsHandler(ws *websocket.Conn) {
|
||||||
|
// io.Copy(ws, ws)
|
||||||
atomic.AddUint64(&b.cid, 1)
|
atomic.AddUint64(&b.cid, 1)
|
||||||
go b.handleConnection(CLIENT, ws, b.cid)
|
ws.PayloadType = websocket.BinaryFrame
|
||||||
|
b.handleConnection(CLIENT, ws)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) StartTLSListening() {
|
func (b *Broker) StartClientListening(Tls bool) {
|
||||||
hp := b.config.TlsHost + ":" + b.config.TlsPort
|
var hp string
|
||||||
log.Info("Start TLS Listening client on ", hp)
|
var err error
|
||||||
|
var l net.Listener
|
||||||
l, e := tls.Listen("tcp", hp, b.tlsConfig)
|
if Tls {
|
||||||
if e != nil {
|
hp = b.config.TlsHost + ":" + b.config.TlsPort
|
||||||
log.Error("Error listening on ", e)
|
l, err = tls.Listen("tcp", hp, b.tlsConfig)
|
||||||
|
log.Info("Start TLS Listening client on ", zap.String("hp", hp))
|
||||||
|
} else {
|
||||||
|
hp := b.config.Host + ":" + b.config.Port
|
||||||
|
l, err = net.Listen("tcp", hp)
|
||||||
|
log.Info("Start Listening client on ", zap.String("hp", hp))
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Error listening on ", zap.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
tmpDelay := 10 * ACCEPT_MIN_SLEEP
|
tmpDelay := 10 * ACCEPT_MIN_SLEEP
|
||||||
@@ -110,36 +226,60 @@ func (b *Broker) StartTLSListening() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||||
log.Error("Temporary Client Accept Error(%v), sleeping %dms",
|
log.Error("Temporary Client Accept Error(%v), sleeping %dms",
|
||||||
ne, tmpDelay/time.Millisecond)
|
zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond))
|
||||||
time.Sleep(tmpDelay)
|
time.Sleep(tmpDelay)
|
||||||
tmpDelay *= 2
|
tmpDelay *= 2
|
||||||
if tmpDelay > ACCEPT_MAX_SLEEP {
|
if tmpDelay > ACCEPT_MAX_SLEEP {
|
||||||
tmpDelay = ACCEPT_MAX_SLEEP
|
tmpDelay = ACCEPT_MAX_SLEEP
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Error("Accept error: %v", err)
|
log.Error("Accept error: %v", zap.Error(err))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tmpDelay = ACCEPT_MIN_SLEEP
|
tmpDelay = ACCEPT_MIN_SLEEP
|
||||||
atomic.AddUint64(&b.cid, 1)
|
atomic.AddUint64(&b.cid, 1)
|
||||||
go b.handleConnection(CLIENT, conn, b.cid)
|
go b.handleConnection(CLIENT, conn)
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) StartListening(typ int) {
|
func (b *Broker) Handshake(conn net.Conn) bool {
|
||||||
var hp string
|
|
||||||
if typ == CLIENT {
|
nc := tls.Server(conn, b.tlsConfig)
|
||||||
hp = b.config.Host + ":" + b.config.Port
|
time.AfterFunc(DEFAULT_TLS_TIMEOUT, func() { TlsTimeout(nc) })
|
||||||
log.Info("Start Listening client on ", hp)
|
nc.SetReadDeadline(time.Now().Add(DEFAULT_TLS_TIMEOUT))
|
||||||
} else if typ == ROUTER {
|
|
||||||
hp = b.config.Cluster.Host + ":" + b.config.Cluster.Port
|
// Force handshake
|
||||||
log.Info("Start Listening cluster on ", hp)
|
if err := nc.Handshake(); err != nil {
|
||||||
|
log.Error("TLS handshake error, ", zap.Error(err))
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
nc.SetReadDeadline(time.Time{})
|
||||||
|
return true
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TlsTimeout(conn *tls.Conn) {
|
||||||
|
nc := conn
|
||||||
|
// Check if already closed
|
||||||
|
if nc == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
cs := nc.ConnectionState()
|
||||||
|
if !cs.HandshakeComplete {
|
||||||
|
log.Error("TLS handshake timeout")
|
||||||
|
nc.Close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) StartClusterListening() {
|
||||||
|
var hp string = b.config.Cluster.Host + ":" + b.config.Cluster.Port
|
||||||
|
log.Info("Start Listening cluster on ", zap.String("hp", hp))
|
||||||
|
|
||||||
l, e := net.Listen("tcp", hp)
|
l, e := net.Listen("tcp", hp)
|
||||||
if e != nil {
|
if e != nil {
|
||||||
log.Error("Error listening on ", e)
|
log.Error("Error listening on ", zap.Error(e))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,60 +289,66 @@ func (b *Broker) StartListening(typ int) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||||
log.Error("Temporary Client Accept Error(%v), sleeping %dms",
|
log.Error("Temporary Client Accept Error(%v), sleeping %dms",
|
||||||
ne, tmpDelay/time.Millisecond)
|
zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond))
|
||||||
time.Sleep(tmpDelay)
|
time.Sleep(tmpDelay)
|
||||||
tmpDelay *= 2
|
tmpDelay *= 2
|
||||||
if tmpDelay > ACCEPT_MAX_SLEEP {
|
if tmpDelay > ACCEPT_MAX_SLEEP {
|
||||||
tmpDelay = ACCEPT_MAX_SLEEP
|
tmpDelay = ACCEPT_MAX_SLEEP
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Error("Accept error: %v", err)
|
log.Error("Accept error: %v", zap.Error(err))
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
tmpDelay = ACCEPT_MIN_SLEEP
|
tmpDelay = ACCEPT_MIN_SLEEP
|
||||||
atomic.AddUint64(&b.cid, 1)
|
|
||||||
go b.handleConnection(typ, conn, b.cid)
|
go b.handleConnection(ROUTER, conn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
|
func (b *Broker) handleConnection(typ int, conn net.Conn) {
|
||||||
//process connect packet
|
//process connect packet
|
||||||
buf, err := ReadPacket(conn)
|
packet, err := packets.ReadPacket(conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("read connect packet error: ", err)
|
log.Error("read connect packet error: ", zap.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
connMsg, err := DecodeConnectMessage(buf)
|
if packet == nil {
|
||||||
if err != nil {
|
log.Error("received nil packet")
|
||||||
log.Error(err)
|
return
|
||||||
|
}
|
||||||
|
msg, ok := packet.(*packets.ConnectPacket)
|
||||||
|
if !ok {
|
||||||
|
log.Error("received msg that was not Connect")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
connack := message.NewConnackMessage()
|
log.Info("reconnect connect from ", zap.String("clientID", msg.ClientIdentifier))
|
||||||
connack.SetReturnCode(message.ConnectionAccepted)
|
|
||||||
ack, _ := EncodeMessage(connack)
|
connack := packets.NewControlPacket(packets.Connack).(*packets.ConnackPacket)
|
||||||
err1 := WriteBuffer(conn, ack)
|
connack.ReturnCode = packets.Accepted
|
||||||
if err1 != nil {
|
connack.SessionPresent = msg.CleanSession
|
||||||
log.Error("send connack error, ", err1)
|
err = connack.Write(conn)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("send connack error, ", zap.Error(err), zap.String("clientID", msg.ClientIdentifier))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
willmsg := message.NewPublishMessage()
|
willmsg := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
|
||||||
if connMsg.WillFlag() {
|
if msg.WillFlag {
|
||||||
willmsg.SetQoS(connMsg.WillQos())
|
willmsg.Qos = msg.WillQos
|
||||||
willmsg.SetPayload(connMsg.WillMessage())
|
willmsg.TopicName = msg.WillTopic
|
||||||
willmsg.SetRetain(connMsg.WillRetain())
|
willmsg.Retain = msg.WillRetain
|
||||||
willmsg.SetTopic(connMsg.WillTopic())
|
willmsg.Payload = msg.WillMessage
|
||||||
willmsg.SetDup(false)
|
willmsg.Dup = msg.Dup
|
||||||
} else {
|
} else {
|
||||||
willmsg = nil
|
willmsg = nil
|
||||||
}
|
}
|
||||||
info := info{
|
info := info{
|
||||||
clientID: connMsg.ClientId(),
|
clientID: msg.ClientIdentifier,
|
||||||
username: connMsg.Username(),
|
username: msg.Username,
|
||||||
password: connMsg.Password(),
|
password: msg.Password,
|
||||||
keepalive: connMsg.KeepAlive(),
|
keepalive: msg.Keepalive,
|
||||||
willMsg: willmsg,
|
willMsg: willmsg,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,113 +358,259 @@ func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
|
|||||||
conn: conn,
|
conn: conn,
|
||||||
info: info,
|
info: info,
|
||||||
}
|
}
|
||||||
|
|
||||||
c.init()
|
c.init()
|
||||||
|
|
||||||
var msgPool *MessagePool
|
err = b.getSession(c, msg, connack)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("get session error: ", zap.String("clientID", c.info.clientID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
cid := c.info.clientID
|
||||||
|
|
||||||
var exist bool
|
var exist bool
|
||||||
var old *client
|
var old interface{}
|
||||||
cid := string(c.info.clientID)
|
|
||||||
if typ == CLIENT {
|
|
||||||
old, exist = b.clients.Update(cid, c)
|
|
||||||
msgPool = MSGPool[idx%MessagePoolNum].GetPool()
|
|
||||||
} else if typ == ROUTER {
|
|
||||||
old, exist = b.routes.Update(cid, c)
|
|
||||||
msgPool = MSGPool[MessagePoolNum].GetPool()
|
|
||||||
}
|
|
||||||
if exist {
|
|
||||||
log.Warn("client or routers exists, close old...")
|
|
||||||
old.Close()
|
|
||||||
}
|
|
||||||
c.readLoop(msgPool)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b *Broker) ConnectToRouters() {
|
switch typ {
|
||||||
for i := 0; i < len(b.config.Cluster.Routes); i++ {
|
case CLIENT:
|
||||||
url := b.config.Cluster.Routes[i]
|
old, exist = b.clients.Load(cid)
|
||||||
go b.connectRouter(url, "")
|
if exist {
|
||||||
}
|
log.Warn("client exist, close old...", zap.String("clientID", c.info.clientID))
|
||||||
}
|
ol, ok := old.(*client)
|
||||||
|
if ok {
|
||||||
func (b *Broker) connectRouter(url, remoteID string) {
|
ol.Close()
|
||||||
for {
|
|
||||||
conn, err := net.Dial("tcp", url)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Error trying to connect to route: ", err)
|
|
||||||
select {
|
|
||||||
case <-time.After(DEFAULT_ROUTE_CONNECT):
|
|
||||||
log.Debug("Connect to route timeout ,retry...")
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
route := &route{
|
b.clients.Store(cid, c)
|
||||||
remoteID: remoteID,
|
|
||||||
remoteUrl: url,
|
b.OnlineOfflineNotification(cid, true)
|
||||||
|
case ROUTER:
|
||||||
|
old, exist = b.routes.Load(cid)
|
||||||
|
if exist {
|
||||||
|
log.Warn("router exist, close old...")
|
||||||
|
ol, ok := old.(*client)
|
||||||
|
if ok {
|
||||||
|
ol.Close()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
cid := GenUniqueId()
|
b.routes.Store(cid, c)
|
||||||
info := info{
|
|
||||||
clientID: []byte(cid),
|
|
||||||
}
|
|
||||||
c := &client{
|
|
||||||
typ: REMOTE,
|
|
||||||
conn: conn,
|
|
||||||
route: route,
|
|
||||||
info: info,
|
|
||||||
}
|
|
||||||
b.remotes.Set(cid, c)
|
|
||||||
c.SendConnect()
|
|
||||||
c.SendInfo()
|
|
||||||
// s.createRemote(conn, route)
|
|
||||||
msgPool := MSGPool[(MessagePoolNum + 1)].GetPool()
|
|
||||||
c.readLoop(msgPool)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// mpool := b.messagePool[fnv1a.HashString64(cid)%MessagePoolNum]
|
||||||
|
|
||||||
|
c.readLoop()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) ConnectToDiscovery() {
|
||||||
|
var conn net.Conn
|
||||||
|
var err error
|
||||||
|
var tempDelay time.Duration = 0
|
||||||
|
for {
|
||||||
|
conn, err = net.Dial("tcp", b.config.Router)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Error trying to connect to route: ", zap.Error(err))
|
||||||
|
log.Debug("Connect to route timeout ,retry...")
|
||||||
|
|
||||||
|
if 0 == tempDelay {
|
||||||
|
tempDelay = 1 * time.Second
|
||||||
|
} else {
|
||||||
|
tempDelay *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if max := 20 * time.Second; tempDelay > max {
|
||||||
|
tempDelay = max
|
||||||
|
}
|
||||||
|
time.Sleep(tempDelay)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
log.Debug("connect to router success :", zap.String("Router", b.config.Router))
|
||||||
|
|
||||||
|
cid := b.id
|
||||||
|
info := info{
|
||||||
|
clientID: cid,
|
||||||
|
keepalive: 60,
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &client{
|
||||||
|
typ: CLUSTER,
|
||||||
|
broker: b,
|
||||||
|
conn: conn,
|
||||||
|
info: info,
|
||||||
|
}
|
||||||
|
|
||||||
|
c.init()
|
||||||
|
|
||||||
|
c.SendConnect()
|
||||||
|
c.SendInfo()
|
||||||
|
|
||||||
|
go c.readLoop()
|
||||||
|
go c.StartPing()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) processClusterInfo() {
|
||||||
|
for {
|
||||||
|
msg, ok := <-b.clusterPool
|
||||||
|
if !ok {
|
||||||
|
log.Error("read message from cluster channel error")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ProcessMessage(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) connectRouter(id, addr string) {
|
||||||
|
var conn net.Conn
|
||||||
|
var err error
|
||||||
|
var timeDelay time.Duration = 0
|
||||||
|
retryTimes := 0
|
||||||
|
max := 32 * time.Second
|
||||||
|
for {
|
||||||
|
|
||||||
|
if !b.checkNodeExist(id, addr) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err = net.Dial("tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Error trying to connect to route: ", zap.Error(err))
|
||||||
|
|
||||||
|
if retryTimes > 50 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debug("Connect to route timeout ,retry...")
|
||||||
|
|
||||||
|
if 0 == timeDelay {
|
||||||
|
timeDelay = 1 * time.Second
|
||||||
|
} else {
|
||||||
|
timeDelay *= 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if timeDelay > max {
|
||||||
|
timeDelay = max
|
||||||
|
}
|
||||||
|
time.Sleep(timeDelay)
|
||||||
|
retryTimes++
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
route := route{
|
||||||
|
remoteID: id,
|
||||||
|
remoteUrl: addr,
|
||||||
|
}
|
||||||
|
cid := GenUniqueId()
|
||||||
|
|
||||||
|
info := info{
|
||||||
|
clientID: cid,
|
||||||
|
keepalive: 60,
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &client{
|
||||||
|
broker: b,
|
||||||
|
typ: REMOTE,
|
||||||
|
conn: conn,
|
||||||
|
route: route,
|
||||||
|
info: info,
|
||||||
|
}
|
||||||
|
c.init()
|
||||||
|
b.remotes.Store(cid, c)
|
||||||
|
|
||||||
|
c.SendConnect()
|
||||||
|
|
||||||
|
// mpool := b.messagePool[fnv1a.HashString64(cid)%MessagePoolNum]
|
||||||
|
go c.readLoop()
|
||||||
|
go c.StartPing()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *Broker) checkNodeExist(id, url string) bool {
|
||||||
|
if id == b.id {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
for k, v := range b.nodes {
|
||||||
|
if k == id {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
//skip
|
||||||
|
l, ok := v.(string)
|
||||||
|
if ok {
|
||||||
|
if url == l {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) CheckRemoteExist(remoteID, url string) bool {
|
func (b *Broker) CheckRemoteExist(remoteID, url string) bool {
|
||||||
exist := false
|
exist := false
|
||||||
remotes := b.remotes.Items()
|
b.remotes.Range(func(key, value interface{}) bool {
|
||||||
for _, v := range remotes {
|
v, ok := value.(*client)
|
||||||
if v.route.remoteUrl == url {
|
if ok {
|
||||||
// if v.route.remoteID == "" || v.route.remoteID != remoteID {
|
if v.route.remoteUrl == url {
|
||||||
v.route.remoteID = remoteID
|
v.route.remoteID = remoteID
|
||||||
// }
|
exist = true
|
||||||
exist = true
|
return false
|
||||||
break
|
}
|
||||||
}
|
}
|
||||||
}
|
return true
|
||||||
|
})
|
||||||
return exist
|
return exist
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) SendLocalSubsToRouter(c *client) {
|
func (b *Broker) SendLocalSubsToRouter(c *client) {
|
||||||
clients := b.clients.Items()
|
subInfo := packets.NewControlPacket(packets.Subscribe).(*packets.SubscribePacket)
|
||||||
subMsg := message.NewSubscribeMessage()
|
b.clients.Range(func(key, value interface{}) bool {
|
||||||
for _, client := range clients {
|
client, ok := value.(*client)
|
||||||
subs := client.subs
|
if ok {
|
||||||
for _, sub := range subs {
|
subs := client.subMap
|
||||||
subMsg.AddTopic(sub.topic, sub.qos)
|
for _, sub := range subs {
|
||||||
|
subInfo.Topics = append(subInfo.Topics, sub.topic)
|
||||||
|
subInfo.Qoss = append(subInfo.Qoss, sub.qos)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
if len(subInfo.Topics) > 0 {
|
||||||
|
err := c.WriterPacket(subInfo)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Send localsubs To Router error :", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
|
||||||
err := c.writeMessage(subMsg)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Send localsubs To Router error :", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) BroadcastInfoMessage(remoteID string, msg message.Message) {
|
func (b *Broker) BroadcastInfoMessage(remoteID string, msg *packets.PublishPacket) {
|
||||||
remotes := b.remotes.Items()
|
b.routes.Range(func(key, value interface{}) bool {
|
||||||
for _, r := range remotes {
|
r, ok := value.(*client)
|
||||||
if r.route.remoteID == remoteID {
|
if ok {
|
||||||
continue
|
if r.route.remoteID == remoteID {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
r.WriterPacket(msg)
|
||||||
}
|
}
|
||||||
r.writeMessage(msg)
|
return true
|
||||||
}
|
|
||||||
|
})
|
||||||
// log.Info("BroadcastInfoMessage success ")
|
// log.Info("BroadcastInfoMessage success ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) BroadcastSubOrUnsubMessage(buf []byte) {
|
func (b *Broker) BroadcastSubOrUnsubMessage(packet packets.ControlPacket) {
|
||||||
remotes := b.remotes.Items()
|
|
||||||
for _, r := range remotes {
|
b.routes.Range(func(key, value interface{}) bool {
|
||||||
r.writeBuffer(buf)
|
r, ok := value.(*client)
|
||||||
}
|
if ok {
|
||||||
|
r.WriterPacket(packet)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
// log.Info("BroadcastSubscribeMessage remotes: ", s.remotes)
|
// log.Info("BroadcastSubscribeMessage remotes: ", s.remotes)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -327,48 +619,54 @@ func (b *Broker) removeClient(c *client) {
|
|||||||
typ := c.typ
|
typ := c.typ
|
||||||
switch typ {
|
switch typ {
|
||||||
case CLIENT:
|
case CLIENT:
|
||||||
b.clients.Remove(clientId)
|
b.clients.Delete(clientId)
|
||||||
case ROUTER:
|
case ROUTER:
|
||||||
b.routes.Remove(clientId)
|
b.routes.Delete(clientId)
|
||||||
case REMOTE:
|
case REMOTE:
|
||||||
b.remotes.Remove(clientId)
|
b.remotes.Delete(clientId)
|
||||||
}
|
}
|
||||||
// log.Info("delete client ,", clientId)
|
// log.Info("delete client ,", clientId)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *Broker) ProcessPublishMessage(msg *message.PublishMessage) {
|
func (b *Broker) PublishMessage(packet *packets.PublishPacket) {
|
||||||
if b == nil {
|
var subs []interface{}
|
||||||
return
|
var qoss []byte
|
||||||
}
|
b.mu.Lock()
|
||||||
topic := string(msg.Topic())
|
err := b.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &subs, &qoss)
|
||||||
|
b.mu.Unlock()
|
||||||
r := b.sl.Match(topic)
|
if err != nil {
|
||||||
// log.Info("psubs num: ", len(r.psubs))
|
log.Error("search sub client error, ", zap.Error(err))
|
||||||
if len(r.qsubs) == 0 && len(r.psubs) == 0 {
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sub := range r.psubs {
|
for _, sub := range subs {
|
||||||
if sub != nil {
|
s, ok := sub.(*subscription)
|
||||||
err := sub.client.writeMessage(msg)
|
if ok {
|
||||||
|
err := s.client.WriterPacket(packet)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("process message for psub error, ", err)
|
log.Error("write message error, ", zap.Error(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
for i, sub := range r.qsubs {
|
|
||||||
// s.qmu.Lock()
|
func (b *Broker) BroadcastUnSubscribe(subs map[string]*subscription) {
|
||||||
if cnt, exist := b.queues[string(sub.topic)]; exist && i == cnt {
|
|
||||||
if sub != nil {
|
unsub := packets.NewControlPacket(packets.Unsubscribe).(*packets.UnsubscribePacket)
|
||||||
err := sub.client.writeMessage(msg)
|
for topic, _ := range subs {
|
||||||
if err != nil {
|
unsub.Topics = append(unsub.Topics, topic)
|
||||||
log.Error("process will message for qsub error, ", err)
|
}
|
||||||
}
|
|
||||||
}
|
if len(unsub.Topics) > 0 {
|
||||||
b.queues[topic] = (b.queues[topic] + 1) % len(r.qsubs)
|
b.BroadcastSubOrUnsubMessage(unsub)
|
||||||
break
|
}
|
||||||
}
|
}
|
||||||
// s.qmu.Unlock()
|
|
||||||
}
|
func (b *Broker) OnlineOfflineNotification(clientID string, online bool) {
|
||||||
|
packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
|
||||||
|
packet.TopicName = "$SYS/broker/connection/clients/" + clientID
|
||||||
|
packet.Qos = 0
|
||||||
|
packet.Payload = []byte(fmt.Sprintf(`{"clientID":"%s","online":%v,"timestamp":"%s"}`, clientID, online, time.Now().UTC().Format(time.RFC3339)))
|
||||||
|
|
||||||
|
b.PublishMessage(packet)
|
||||||
}
|
}
|
||||||
|
|||||||
+259
-282
@@ -1,49 +1,69 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
package broker
|
package broker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"hmq/lib/message"
|
|
||||||
"net"
|
"net"
|
||||||
|
"reflect"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/cihub/seelog"
|
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||||
|
"github.com/fhmq/hmq/lib/sessions"
|
||||||
|
"github.com/fhmq/hmq/lib/topics"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// special pub topic for cluster info BrokerInfoTopic
|
// special pub topic for cluster info BrokerInfoTopic
|
||||||
BrokerInfoTopic = "broker001info/brokerinfo"
|
BrokerInfoTopic = "broker000100101info"
|
||||||
// CLIENT is an end user.
|
// CLIENT is an end user.
|
||||||
CLIENT = 0
|
CLIENT = 0
|
||||||
// ROUTER is another router in the cluster.
|
// ROUTER is another router in the cluster.
|
||||||
ROUTER = 1
|
ROUTER = 1
|
||||||
//REMOTE is the router connect to other cluster
|
//REMOTE is the router connect to other cluster
|
||||||
REMOTE = 2
|
REMOTE = 2
|
||||||
|
CLUSTER = 3
|
||||||
|
)
|
||||||
|
const (
|
||||||
|
Connected = 1
|
||||||
|
Disconnected = 2
|
||||||
)
|
)
|
||||||
|
|
||||||
type client struct {
|
type client struct {
|
||||||
typ int
|
typ int
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
broker *Broker
|
broker *Broker
|
||||||
conn net.Conn
|
conn net.Conn
|
||||||
info info
|
info info
|
||||||
route *route
|
route route
|
||||||
subs map[string]*subscription
|
status int
|
||||||
|
ctx context.Context
|
||||||
|
cancelFunc context.CancelFunc
|
||||||
|
session *sessions.Session
|
||||||
|
subMap map[string]*subscription
|
||||||
|
topicsMgr *topics.Manager
|
||||||
|
subs []interface{}
|
||||||
|
qoss []byte
|
||||||
|
rmsgs []*packets.PublishPacket
|
||||||
}
|
}
|
||||||
|
|
||||||
type subscription struct {
|
type subscription struct {
|
||||||
client *client
|
client *client
|
||||||
topic []byte
|
topic string
|
||||||
qos byte
|
qos byte
|
||||||
queue bool
|
queue bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type info struct {
|
type info struct {
|
||||||
clientID []byte
|
clientID string
|
||||||
username []byte
|
username string
|
||||||
password []byte
|
password []byte
|
||||||
keepalive uint16
|
keepalive uint16
|
||||||
willMsg *message.PublishMessage
|
willMsg *packets.PublishPacket
|
||||||
localIP string
|
localIP string
|
||||||
remoteIP string
|
remoteIP string
|
||||||
}
|
}
|
||||||
@@ -53,386 +73,343 @@ type route struct {
|
|||||||
remoteUrl string
|
remoteUrl string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
DisconnectdPacket = packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket)
|
||||||
|
)
|
||||||
|
|
||||||
func (c *client) init() {
|
func (c *client) init() {
|
||||||
c.subs = make(map[string]*subscription, 10)
|
c.status = Connected
|
||||||
c.info.localIP = strings.Split(c.conn.LocalAddr().String(), ":")[0]
|
c.info.localIP = strings.Split(c.conn.LocalAddr().String(), ":")[0]
|
||||||
c.info.remoteIP = strings.Split(c.conn.RemoteAddr().String(), ":")[0]
|
c.info.remoteIP = strings.Split(c.conn.RemoteAddr().String(), ":")[0]
|
||||||
|
c.ctx, c.cancelFunc = context.WithCancel(context.Background())
|
||||||
|
c.subMap = make(map[string]*subscription)
|
||||||
|
c.topicsMgr = c.broker.topicsMgr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) readLoop(msgPool *MessagePool) {
|
func (c *client) readLoop() {
|
||||||
nc := c.conn
|
nc := c.conn
|
||||||
if nc == nil || msgPool == nil {
|
b := c.broker
|
||||||
|
if nc == nil || b == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msg := &Message{}
|
|
||||||
|
keepAlive := time.Second * time.Duration(c.info.keepalive)
|
||||||
|
timeOut := keepAlive + (keepAlive / 2)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
buf, err := ReadPacket(nc)
|
select {
|
||||||
if err != nil {
|
case <-c.ctx.Done():
|
||||||
log.Error("read packet error: ", err)
|
|
||||||
c.Close()
|
|
||||||
return
|
return
|
||||||
|
default:
|
||||||
|
//add read timeout
|
||||||
|
if err := nc.SetReadDeadline(time.Now().Add(timeOut)); err != nil {
|
||||||
|
log.Error("set read timeout error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
packet, err := packets.ReadPacket(nc)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("read packet error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||||
|
msg := &Message{client: c, packet: DisconnectdPacket}
|
||||||
|
b.SubmitWork(c.info.clientID, msg)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := &Message{
|
||||||
|
client: c,
|
||||||
|
packet: packet,
|
||||||
|
}
|
||||||
|
b.SubmitWork(c.info.clientID, msg)
|
||||||
}
|
}
|
||||||
msg.client = c
|
|
||||||
msg.msg = buf
|
|
||||||
msgPool.queue <- msg
|
|
||||||
}
|
}
|
||||||
msgPool.Reduce()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ProcessMessage(msg *Message) {
|
func ProcessMessage(msg *Message) {
|
||||||
buf := msg.msg
|
|
||||||
c := msg.client
|
c := msg.client
|
||||||
if c == nil || buf == nil {
|
ca := msg.packet
|
||||||
|
if ca == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msgType := uint8(buf[0] & 0xF0 >> 4)
|
log.Debug("Recv message:", zap.String("message type", reflect.TypeOf(msg.packet).String()[9:]), zap.String("ClientID", c.info.clientID))
|
||||||
switch msgType {
|
switch ca.(type) {
|
||||||
case CONNACK:
|
case *packets.ConnackPacket:
|
||||||
// log.Info("Recv conack message..........")
|
case *packets.ConnectPacket:
|
||||||
c.ProcessConnAck(buf)
|
case *packets.PublishPacket:
|
||||||
case CONNECT:
|
packet := ca.(*packets.PublishPacket)
|
||||||
// log.Info("Recv connect message..........")
|
c.ProcessPublish(packet)
|
||||||
c.ProcessConnect(buf)
|
case *packets.PubackPacket:
|
||||||
case PUBLISH:
|
case *packets.PubrecPacket:
|
||||||
// log.Info("Recv publish message..........")
|
case *packets.PubrelPacket:
|
||||||
c.ProcessPublish(buf)
|
case *packets.PubcompPacket:
|
||||||
case PUBACK:
|
case *packets.SubscribePacket:
|
||||||
//log.Info("Recv publish ack message..........")
|
packet := ca.(*packets.SubscribePacket)
|
||||||
c.ProcessPubAck(buf)
|
c.ProcessSubscribe(packet)
|
||||||
case PUBCOMP:
|
case *packets.SubackPacket:
|
||||||
//log.Info("Recv publish ack message..........")
|
case *packets.UnsubscribePacket:
|
||||||
c.ProcessPubComp(buf)
|
packet := ca.(*packets.UnsubscribePacket)
|
||||||
case PUBREC:
|
c.ProcessUnSubscribe(packet)
|
||||||
//log.Info("Recv publish rec message..........")
|
case *packets.UnsubackPacket:
|
||||||
c.ProcessPubREC(buf)
|
case *packets.PingreqPacket:
|
||||||
case PUBREL:
|
c.ProcessPing()
|
||||||
//log.Info("Recv publish rel message..........")
|
case *packets.PingrespPacket:
|
||||||
c.ProcessPubREL(buf)
|
case *packets.DisconnectPacket:
|
||||||
case SUBSCRIBE:
|
|
||||||
// log.Info("Recv subscribe message.....")
|
|
||||||
c.ProcessSubscribe(buf)
|
|
||||||
case SUBACK:
|
|
||||||
// log.Info("Recv suback message.....")
|
|
||||||
case UNSUBSCRIBE:
|
|
||||||
// log.Info("Recv unsubscribe message.....")
|
|
||||||
c.ProcessUnSubscribe(buf)
|
|
||||||
case UNSUBACK:
|
|
||||||
//log.Info("Recv unsuback message.....")
|
|
||||||
case PINGREQ:
|
|
||||||
// log.Info("Recv PINGREQ message..........")
|
|
||||||
c.ProcessPing(buf)
|
|
||||||
case PINGRESP:
|
|
||||||
//log.Info("Recv PINGRESP message..........")
|
|
||||||
case DISCONNECT:
|
|
||||||
// log.Info("Recv DISCONNECT message.......")
|
|
||||||
c.Close()
|
c.Close()
|
||||||
default:
|
default:
|
||||||
log.Info("Recv Unknow message.......")
|
log.Info("Recv Unknow message.......", zap.String("ClientID", c.info.clientID))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) ProcessConnect(buf []byte) {
|
func (c *client) ProcessPublish(packet *packets.PublishPacket) {
|
||||||
|
if c.status == Disconnected {
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) ProcessConnAck(buf []byte) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) ProcessPublish(buf []byte) {
|
|
||||||
msg, err := DecodePublishMessage(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Decode Publish Message error: ", err)
|
|
||||||
c.Close()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
topic := msg.Topic()
|
|
||||||
|
|
||||||
if c.typ != CLIENT || !c.CheckTopicAuth(PUB, string(topic)) {
|
topic := packet.TopicName
|
||||||
|
if topic == BrokerInfoTopic && c.typ == CLUSTER {
|
||||||
|
c.ProcessInfo(packet)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
c.ProcessPublishMessage(buf, msg)
|
|
||||||
|
|
||||||
if msg.Retain() {
|
if !c.CheckTopicAuth(PUB, topic) {
|
||||||
if b := c.broker; b != nil {
|
log.Error("Pub Topics Auth failed, ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID))
|
||||||
err := b.rl.Insert(topic, buf)
|
return
|
||||||
if err != nil {
|
}
|
||||||
log.Error("Insert Retain Message error: ", err)
|
|
||||||
}
|
switch packet.Qos {
|
||||||
|
case QosAtMostOnce:
|
||||||
|
c.ProcessPublishMessage(packet)
|
||||||
|
case QosAtLeastOnce:
|
||||||
|
puback := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
|
||||||
|
puback.MessageID = packet.MessageID
|
||||||
|
if err := c.WriterPacket(puback); err != nil {
|
||||||
|
log.Error("send puback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
c.ProcessPublishMessage(packet)
|
||||||
|
case QosExactlyOnce:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
log.Error("publish with unknown qos", zap.String("ClientID", c.info.clientID))
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) ProcessPublishMessage(buf []byte, msg *message.PublishMessage) {
|
func (c *client) ProcessPublishMessage(packet *packets.PublishPacket) {
|
||||||
|
|
||||||
b := c.broker
|
b := c.broker
|
||||||
if b == nil {
|
if b == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
typ := c.typ
|
typ := c.typ
|
||||||
topic := string(msg.Topic())
|
|
||||||
|
|
||||||
r := b.sl.Match(topic)
|
if packet.Retain {
|
||||||
// log.Info("psubs num: ", len(r.psubs))
|
if err := c.topicsMgr.Retain(packet); err != nil {
|
||||||
if len(r.qsubs) == 0 && len(r.psubs) == 0 {
|
log.Error("Error retaining message: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.mu.Lock()
|
||||||
|
err := c.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &c.subs, &c.qoss)
|
||||||
|
c.mu.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
log.Error("Error retrieving subscribers list: ", zap.String("ClientID", c.info.clientID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, sub := range r.psubs {
|
// log.Info("psubs num: ", len(r.psubs))
|
||||||
if sub.client.typ == ROUTER {
|
if len(c.subs) == 0 {
|
||||||
if typ == ROUTER {
|
return
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if sub != nil {
|
|
||||||
err := sub.client.writeBuffer(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("process message for psub error, ", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, sub := range r.qsubs {
|
for _, sub := range c.subs {
|
||||||
if sub.client.typ == ROUTER {
|
s, ok := sub.(*subscription)
|
||||||
if typ == ROUTER {
|
if ok {
|
||||||
continue
|
if s.client.typ == ROUTER {
|
||||||
}
|
if typ != CLIENT {
|
||||||
}
|
continue
|
||||||
// s.qmu.Lock()
|
|
||||||
if cnt, exist := b.queues[string(sub.topic)]; exist && i == cnt {
|
|
||||||
if sub != nil {
|
|
||||||
err := sub.client.writeBuffer(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("process will message for qsub error, ", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
b.queues[topic] = (b.queues[topic] + 1) % len(r.qsubs)
|
err := s.client.WriterPacket(packet)
|
||||||
break
|
if err != nil {
|
||||||
|
log.Error("process message for psub error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// s.qmu.Unlock()
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) ProcessPubAck(buf []byte) {
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) ProcessPubREC(buf []byte) {
|
func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) {
|
||||||
|
if c.status == Disconnected {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) ProcessPubREL(buf []byte) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) ProcessPubComp(buf []byte) {
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) ProcessSubscribe(buf []byte) {
|
|
||||||
b := c.broker
|
b := c.broker
|
||||||
if b == nil {
|
if b == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
msg, err := DecodeSubscribeMessage(buf)
|
topics := packet.Topics
|
||||||
if err != nil {
|
qoss := packet.Qoss
|
||||||
log.Error("Decode Subscribe Message error: ", err)
|
|
||||||
c.Close()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
topics := msg.Topics()
|
|
||||||
qos := msg.Qos()
|
|
||||||
|
|
||||||
suback := message.NewSubackMessage()
|
suback := packets.NewControlPacket(packets.Suback).(*packets.SubackPacket)
|
||||||
suback.SetPacketId(msg.PacketId())
|
suback.MessageID = packet.MessageID
|
||||||
var retcodes []byte
|
var retcodes []byte
|
||||||
|
|
||||||
for i, t := range topics {
|
for i, topic := range topics {
|
||||||
topic := string(t)
|
t := topic
|
||||||
//check topic auth for client
|
//check topic auth for client
|
||||||
if c.typ == CLIENT {
|
if !c.CheckTopicAuth(SUB, topic) {
|
||||||
if !c.CheckTopicAuth(SUB, topic) {
|
log.Error("Sub topic Auth failed: ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID))
|
||||||
log.Error("CheckSubAuth failed")
|
retcodes = append(retcodes, QosFailure)
|
||||||
retcodes = append(retcodes, message.QosFailure)
|
continue
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if _, exist := c.subs[topic]; !exist {
|
|
||||||
queue := false
|
|
||||||
if strings.HasPrefix(topic, "$queue/") {
|
|
||||||
if len(t) > 7 {
|
|
||||||
t = t[7:]
|
|
||||||
queue = true
|
|
||||||
// b.qmu.Lock()
|
|
||||||
if _, exists := b.queues[topic]; !exists {
|
|
||||||
b.queues[topic] = 0
|
|
||||||
}
|
|
||||||
// b.qmu.Unlock()
|
|
||||||
} else {
|
|
||||||
retcodes = append(retcodes, message.QosFailure)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
sub := &subscription{
|
|
||||||
topic: t,
|
|
||||||
qos: qos[i],
|
|
||||||
client: c,
|
|
||||||
queue: queue,
|
|
||||||
}
|
|
||||||
|
|
||||||
c.mu.Lock()
|
sub := &subscription{
|
||||||
c.subs[topic] = sub
|
topic: t,
|
||||||
c.mu.Unlock()
|
qos: qoss[i],
|
||||||
|
client: c,
|
||||||
err := b.sl.Insert(sub)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Insert subscription error: ", err)
|
|
||||||
retcodes = append(retcodes, message.QosFailure)
|
|
||||||
}
|
|
||||||
retcodes = append(retcodes, qos[i])
|
|
||||||
} else {
|
|
||||||
//if exist ,check whether qos change
|
|
||||||
c.subs[topic].qos = qos[i]
|
|
||||||
retcodes = append(retcodes, qos[i])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
rqos, err := c.topicsMgr.Subscribe([]byte(topic), qoss[i], sub)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.subMap[topic] = sub
|
||||||
|
c.session.AddTopic(topic, qoss[i])
|
||||||
|
retcodes = append(retcodes, rqos)
|
||||||
|
c.topicsMgr.Retained([]byte(topic), &c.rmsgs)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := suback.AddReturnCodes(retcodes); err != nil {
|
suback.ReturnCodes = retcodes
|
||||||
log.Error("add return suback code error, ", err)
|
|
||||||
// if typ == CLIENT {
|
|
||||||
c.Close()
|
|
||||||
// }
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err1 := c.writeMessage(suback)
|
err := c.WriterPacket(suback)
|
||||||
if err1 != nil {
|
if err != nil {
|
||||||
log.Error("send suback error, ", err1)
|
log.Error("send suback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
//broadcast subscribe message
|
//broadcast subscribe message
|
||||||
if c.typ == CLIENT {
|
if c.typ == CLIENT {
|
||||||
go b.BroadcastSubOrUnsubMessage(buf)
|
go b.BroadcastSubOrUnsubMessage(packet)
|
||||||
}
|
}
|
||||||
|
|
||||||
//process retain message
|
//process retain message
|
||||||
for _, t := range topics {
|
for _, rm := range c.rmsgs {
|
||||||
bufs := b.rl.Match(t)
|
if err := c.WriterPacket(rm); err != nil {
|
||||||
for _, buf := range bufs {
|
log.Error("Error publishing retained message:", zap.Any("err", err), zap.String("ClientID", c.info.clientID))
|
||||||
log.Info("process retain message: ", string(buf))
|
} else {
|
||||||
if buf != nil && string(buf) != "" {
|
log.Info("process retain message: ", zap.Any("packet", packet), zap.String("ClientID", c.info.clientID))
|
||||||
c.writeBuffer(buf)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) ProcessUnSubscribe(buf []byte) {
|
func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) {
|
||||||
|
if c.status == Disconnected {
|
||||||
|
return
|
||||||
|
}
|
||||||
b := c.broker
|
b := c.broker
|
||||||
if b == nil {
|
if b == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
topics := packet.Topics
|
||||||
|
|
||||||
unsub, err := DecodeUnsubscribeMessage(buf)
|
for _, topic := range topics {
|
||||||
if err != nil {
|
t := []byte(topic)
|
||||||
log.Error("Decode UnSubscribe Message error: ", err)
|
sub, exist := c.subMap[topic]
|
||||||
c.Close()
|
if exist {
|
||||||
return
|
c.topicsMgr.Unsubscribe(t, sub)
|
||||||
}
|
c.session.RemoveTopic(topic)
|
||||||
topics := unsub.Topics()
|
delete(c.subMap, topic)
|
||||||
|
|
||||||
for _, t := range topics {
|
|
||||||
var sub *subscription
|
|
||||||
ok := false
|
|
||||||
|
|
||||||
if sub, ok = c.subs[string(t)]; ok {
|
|
||||||
go c.unsubscribe(sub)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
resp := message.NewUnsubackMessage()
|
unsuback := packets.NewControlPacket(packets.Unsuback).(*packets.UnsubackPacket)
|
||||||
resp.SetPacketId(unsub.PacketId())
|
unsuback.MessageID = packet.MessageID
|
||||||
|
|
||||||
err1 := c.writeMessage(resp)
|
err := c.WriterPacket(unsuback)
|
||||||
if err1 != nil {
|
if err != nil {
|
||||||
log.Error("send ubsuback error, ", err1)
|
log.Error("send unsuback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// //process ubsubscribe message
|
// //process ubsubscribe message
|
||||||
if c.typ == CLIENT {
|
if c.typ == CLIENT {
|
||||||
b.BroadcastSubOrUnsubMessage(buf)
|
b.BroadcastSubOrUnsubMessage(packet)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) unsubscribe(sub *subscription) {
|
func (c *client) ProcessPing() {
|
||||||
|
if c.status == Disconnected {
|
||||||
c.mu.Lock()
|
|
||||||
delete(c.subs, string(sub.topic))
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
if c.broker != nil {
|
|
||||||
c.broker.sl.Remove(sub)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *client) ProcessPing(buf []byte) {
|
|
||||||
_, err := DecodePingreqMessage(buf)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("Decode PingRequest Message error: ", err)
|
|
||||||
c.Close()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
resp := packets.NewControlPacket(packets.Pingresp).(*packets.PingrespPacket)
|
||||||
pingRspMsg := message.NewPingrespMessage()
|
err := c.WriterPacket(resp)
|
||||||
err = c.writeMessage(pingRspMsg)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("send PingResponse error, ", err)
|
log.Error("send PingResponse error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) Close() {
|
func (c *client) Close() {
|
||||||
b := c.broker
|
if c.status == Disconnected {
|
||||||
subs := c.subs
|
return
|
||||||
if b != nil {
|
|
||||||
b.removeClient(c)
|
|
||||||
for _, sub := range subs {
|
|
||||||
err := b.sl.Remove(sub)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("closed client but remove sublist error, ", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if c.info.willMsg != nil {
|
|
||||||
b.ProcessPublishMessage(c.info.willMsg)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.cancelFunc()
|
||||||
|
|
||||||
|
c.status = Disconnected
|
||||||
|
//wait for message complete
|
||||||
|
// time.Sleep(1 * time.Second)
|
||||||
|
// c.status = Disconnected
|
||||||
|
|
||||||
if c.conn != nil {
|
if c.conn != nil {
|
||||||
c.conn.Close()
|
c.conn.Close()
|
||||||
c.conn = nil
|
c.conn = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
b := c.broker
|
||||||
|
subs := c.subMap
|
||||||
|
if b != nil {
|
||||||
|
b.removeClient(c)
|
||||||
|
|
||||||
|
if c.typ == CLIENT {
|
||||||
|
b.BroadcastUnSubscribe(subs)
|
||||||
|
//offline notification
|
||||||
|
b.OnlineOfflineNotification(c.info.clientID, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.info.willMsg != nil {
|
||||||
|
b.PublishMessage(c.info.willMsg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.typ == CLUSTER {
|
||||||
|
b.ConnectToDiscovery()
|
||||||
|
}
|
||||||
|
|
||||||
|
//do reconnect
|
||||||
|
if c.typ == REMOTE {
|
||||||
|
go b.connectRouter(c.route.remoteID, c.route.remoteUrl)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WriteBuffer(conn net.Conn, buf []byte) error {
|
func (c *client) WriterPacket(packet packets.ControlPacket) error {
|
||||||
if conn == nil {
|
if c.status == Disconnected {
|
||||||
return errors.New("conn is nul")
|
return nil
|
||||||
}
|
}
|
||||||
_, err := conn.Write(buf)
|
|
||||||
return err
|
if packet == nil {
|
||||||
}
|
return nil
|
||||||
func (c *client) writeBuffer(buf []byte) error {
|
}
|
||||||
|
if c.conn == nil {
|
||||||
|
c.Close()
|
||||||
|
return errors.New("connect lost ....")
|
||||||
|
}
|
||||||
|
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
err := WriteBuffer(c.conn, buf)
|
err := packet.Write(c.conn)
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) writeMessage(msg message.Message) error {
|
|
||||||
buf, err := EncodeMessage(msg)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return c.writeBuffer(buf)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,73 +0,0 @@
|
|||||||
package broker
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
type cMap interface {
|
|
||||||
Set(key string, val *client)
|
|
||||||
Get(key string) (*client, bool)
|
|
||||||
Items() map[string]*client
|
|
||||||
Exist(key string) bool
|
|
||||||
Update(key string, val *client) (*client, bool)
|
|
||||||
Count() int
|
|
||||||
Remove(key string)
|
|
||||||
}
|
|
||||||
|
|
||||||
type clientMap struct {
|
|
||||||
items map[string]*client
|
|
||||||
mu sync.RWMutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewClientMap() cMap {
|
|
||||||
smap := &clientMap{
|
|
||||||
items: make(map[string]*client),
|
|
||||||
}
|
|
||||||
return smap
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *clientMap) Set(key string, val *client) {
|
|
||||||
s.mu.Lock()
|
|
||||||
s.items[key] = val
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *clientMap) Get(key string) (*client, bool) {
|
|
||||||
s.mu.RLock()
|
|
||||||
val, ok := s.items[key]
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return val, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *clientMap) Exist(key string) bool {
|
|
||||||
s.mu.RLock()
|
|
||||||
_, ok := s.items[key]
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *clientMap) Update(key string, val *client) (*client, bool) {
|
|
||||||
s.mu.Lock()
|
|
||||||
old, ok := s.items[key]
|
|
||||||
s.items[key] = val
|
|
||||||
s.mu.Unlock()
|
|
||||||
return old, ok
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *clientMap) Count() int {
|
|
||||||
s.mu.RLock()
|
|
||||||
len := len(s.items)
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return len
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *clientMap) Remove(key string) {
|
|
||||||
s.mu.Lock()
|
|
||||||
delete(s.items, key)
|
|
||||||
s.mu.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *clientMap) Items() map[string]*client {
|
|
||||||
s.mu.RLock()
|
|
||||||
items := s.items
|
|
||||||
s.mu.RUnlock()
|
|
||||||
return items
|
|
||||||
}
|
|
||||||
+8
-51
@@ -1,15 +1,14 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
package broker
|
package broker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"crypto/md5"
|
"crypto/md5"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/base64"
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"errors"
|
|
||||||
"io"
|
"io"
|
||||||
"reflect"
|
"reflect"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -40,54 +39,12 @@ const (
|
|||||||
PINGRESP
|
PINGRESP
|
||||||
DISCONNECT
|
DISCONNECT
|
||||||
)
|
)
|
||||||
|
const (
|
||||||
func SubscribeTopicCheckAndSpilt(subject []byte) ([]string, error) {
|
QosAtMostOnce byte = iota
|
||||||
|
QosAtLeastOnce
|
||||||
topic := string(subject)
|
QosExactlyOnce
|
||||||
|
QosFailure = 0x80
|
||||||
if bytes.IndexByte(subject, '#') != -1 {
|
)
|
||||||
if bytes.IndexByte(subject, '#') != len(subject)-1 {
|
|
||||||
return nil, errors.New("Topic format error with index of #")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
re := strings.Split(topic, "/")
|
|
||||||
for i, v := range re {
|
|
||||||
if i != 0 && i != (len(re)-1) {
|
|
||||||
if v == "" {
|
|
||||||
return nil, errors.New("Topic format error with index of //")
|
|
||||||
}
|
|
||||||
if strings.Contains(v, "+") && v != "+" {
|
|
||||||
return nil, errors.New("Topic format error with index of +")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if v == "" {
|
|
||||||
re[i] = "/"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return re, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func PublishTopicCheckAndSpilt(subject []byte) ([]string, error) {
|
|
||||||
if bytes.IndexByte(subject, '#') != -1 || bytes.IndexByte(subject, '+') != -1 {
|
|
||||||
return nil, errors.New("Publish Topic format error with + and #")
|
|
||||||
}
|
|
||||||
topic := string(subject)
|
|
||||||
re := strings.Split(topic, "/")
|
|
||||||
for i, v := range re {
|
|
||||||
if v == "" {
|
|
||||||
if i != 0 && i != (len(re)-1) {
|
|
||||||
return nil, errors.New("Topic format error with index of //")
|
|
||||||
} else {
|
|
||||||
re[i] = "/"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
return re, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func equal(k1, k2 interface{}) bool {
|
func equal(k1, k2 interface{}) bool {
|
||||||
if reflect.TypeOf(k1) != reflect.TypeOf(k2) {
|
if reflect.TypeOf(k1) != reflect.TypeOf(k2) {
|
||||||
|
|||||||
+117
-17
@@ -1,23 +1,27 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
package broker
|
package broker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
"crypto/x509"
|
"crypto/x509"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
|
||||||
log "github.com/cihub/seelog"
|
"github.com/fhmq/hmq/logger"
|
||||||
)
|
"go.uber.org/zap"
|
||||||
|
|
||||||
const (
|
|
||||||
CONFIGFILE = "broker.config"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
Worker int `json:"workerNum"`
|
||||||
Host string `json:"host"`
|
Host string `json:"host"`
|
||||||
Port string `json:"port"`
|
Port string `json:"port"`
|
||||||
Cluster RouteInfo `json:"cluster"`
|
Cluster RouteInfo `json:"cluster"`
|
||||||
|
Router string `json:"router"`
|
||||||
TlsHost string `json:"tlsHost"`
|
TlsHost string `json:"tlsHost"`
|
||||||
TlsPort string `json:"tlsPort"`
|
TlsPort string `json:"tlsPort"`
|
||||||
WsPath string `json:"wsPath"`
|
WsPath string `json:"wsPath"`
|
||||||
@@ -26,12 +30,12 @@ type Config struct {
|
|||||||
TlsInfo TLSInfo `json:"tlsInfo"`
|
TlsInfo TLSInfo `json:"tlsInfo"`
|
||||||
Acl bool `json:"acl"`
|
Acl bool `json:"acl"`
|
||||||
AclConf string `json:"aclConf"`
|
AclConf string `json:"aclConf"`
|
||||||
|
Debug bool `json:"-"`
|
||||||
}
|
}
|
||||||
|
|
||||||
type RouteInfo struct {
|
type RouteInfo struct {
|
||||||
Host string `json:"host"`
|
Host string `json:"host"`
|
||||||
Port string `json:"port"`
|
Port string `json:"port"`
|
||||||
Routes []string `json:"routes"`
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type TLSInfo struct {
|
type TLSInfo struct {
|
||||||
@@ -41,11 +45,94 @@ type TLSInfo struct {
|
|||||||
KeyFile string `json:"keyFile"`
|
KeyFile string `json:"keyFile"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func LoadConfig() (*Config, error) {
|
var DefaultConfig *Config = &Config{
|
||||||
|
Worker: 4096,
|
||||||
|
Host: "0.0.0.0",
|
||||||
|
Port: "1883",
|
||||||
|
Acl: false,
|
||||||
|
}
|
||||||
|
|
||||||
content, err := ioutil.ReadFile(CONFIGFILE)
|
var (
|
||||||
|
log *zap.Logger
|
||||||
|
)
|
||||||
|
|
||||||
|
func showHelp() {
|
||||||
|
fmt.Printf("%s\n", usageStr)
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConfigureConfig(args []string) (*Config, error) {
|
||||||
|
config := &Config{}
|
||||||
|
var (
|
||||||
|
help bool
|
||||||
|
configFile string
|
||||||
|
)
|
||||||
|
fs := flag.NewFlagSet("hmq-broker", flag.ExitOnError)
|
||||||
|
fs.Usage = showHelp
|
||||||
|
|
||||||
|
fs.BoolVar(&help, "h", false, "Show this message.")
|
||||||
|
fs.BoolVar(&help, "help", false, "Show this message.")
|
||||||
|
fs.IntVar(&config.Worker, "w", 1024, "worker num to process message, perfer (client num)/10.")
|
||||||
|
fs.IntVar(&config.Worker, "worker", 1024, "worker num to process message, perfer (client num)/10.")
|
||||||
|
fs.StringVar(&config.Port, "port", "1883", "Port to listen on.")
|
||||||
|
fs.StringVar(&config.Port, "p", "1883", "Port to listen on.")
|
||||||
|
fs.StringVar(&config.Host, "host", "0.0.0.0", "Network host to listen on")
|
||||||
|
fs.StringVar(&config.Cluster.Port, "cp", "", "Cluster port from which members can connect.")
|
||||||
|
fs.StringVar(&config.Cluster.Port, "clusterport", "", "Cluster port from which members can connect.")
|
||||||
|
fs.StringVar(&config.Router, "r", "", "Router who maintenance cluster info")
|
||||||
|
fs.StringVar(&config.Router, "router", "", "Router who maintenance cluster info")
|
||||||
|
fs.StringVar(&config.WsPort, "ws", "", "port for ws to listen on")
|
||||||
|
fs.StringVar(&config.WsPort, "wsport", "", "port for ws to listen on")
|
||||||
|
fs.StringVar(&config.WsPath, "wsp", "", "path for ws to listen on")
|
||||||
|
fs.StringVar(&config.WsPath, "wspath", "", "path for ws to listen on")
|
||||||
|
fs.StringVar(&configFile, "config", "", "config file for hmq")
|
||||||
|
fs.StringVar(&configFile, "c", "", "config file for hmq")
|
||||||
|
fs.BoolVar(&config.Debug, "debug", false, "enable Debug logging.")
|
||||||
|
fs.BoolVar(&config.Debug, "d", false, "enable Debug logging.")
|
||||||
|
|
||||||
|
fs.Bool("D", true, "enable Debug logging.")
|
||||||
|
|
||||||
|
if err := fs.Parse(args); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if help {
|
||||||
|
showHelp()
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fs.Visit(func(f *flag.Flag) {
|
||||||
|
switch f.Name {
|
||||||
|
case "D":
|
||||||
|
config.Debug = true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.InitLogger(config.Debug)
|
||||||
|
log = logger.Get().Named("Broker")
|
||||||
|
|
||||||
|
if configFile != "" {
|
||||||
|
tmpConfig, e := LoadConfig(configFile)
|
||||||
|
if e != nil {
|
||||||
|
return nil, e
|
||||||
|
} else {
|
||||||
|
config = tmpConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := config.check(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return config, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadConfig(filename string) (*Config, error) {
|
||||||
|
|
||||||
|
content, err := ioutil.ReadFile(filename)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Read config file error: ", err)
|
log.Error("Read config file error: ", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
// log.Info(string(content))
|
// log.Info(string(content))
|
||||||
@@ -53,10 +140,19 @@ func LoadConfig() (*Config, error) {
|
|||||||
var config Config
|
var config Config
|
||||||
err = json.Unmarshal(content, &config)
|
err = json.Unmarshal(content, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("Unmarshal config file error: ", err)
|
log.Error("Unmarshal config file error: ", zap.Error(err))
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return &config, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (config *Config) check() error {
|
||||||
|
|
||||||
|
if config.Worker == 0 {
|
||||||
|
config.Worker = 1024
|
||||||
|
}
|
||||||
|
|
||||||
if config.Port != "" {
|
if config.Port != "" {
|
||||||
if config.Host == "" {
|
if config.Host == "" {
|
||||||
config.Host = "0.0.0.0"
|
config.Host = "0.0.0.0"
|
||||||
@@ -68,29 +164,33 @@ func LoadConfig() (*Config, error) {
|
|||||||
config.Cluster.Host = "0.0.0.0"
|
config.Cluster.Host = "0.0.0.0"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if config.Router != "" {
|
||||||
|
if config.Cluster.Port == "" {
|
||||||
|
return errors.New("cluster port is null")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if config.TlsPort != "" {
|
if config.TlsPort != "" {
|
||||||
if config.TlsInfo.CertFile == "" || config.TlsInfo.KeyFile == "" {
|
if config.TlsInfo.CertFile == "" || config.TlsInfo.KeyFile == "" {
|
||||||
log.Error("tls config error, no cert or key file.")
|
log.Error("tls config error, no cert or key file.")
|
||||||
return nil, err
|
return errors.New("tls config error, no cert or key file.")
|
||||||
}
|
}
|
||||||
if config.TlsHost == "" {
|
if config.TlsHost == "" {
|
||||||
config.TlsHost = "0.0.0.0"
|
config.TlsHost = "0.0.0.0"
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
return nil
|
||||||
return &config, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTLSConfig(tlsInfo TLSInfo) (*tls.Config, error) {
|
func NewTLSConfig(tlsInfo TLSInfo) (*tls.Config, error) {
|
||||||
|
|
||||||
cert, err := tls.LoadX509KeyPair(tlsInfo.CertFile, tlsInfo.KeyFile)
|
cert, err := tls.LoadX509KeyPair(tlsInfo.CertFile, tlsInfo.KeyFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", err)
|
return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", zap.Error(err))
|
||||||
}
|
}
|
||||||
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
return nil, fmt.Errorf("error parsing certificate: %v", zap.Error(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create TLSConfig
|
// Create TLSConfig
|
||||||
|
|||||||
@@ -1,46 +0,0 @@
|
|||||||
package broker
|
|
||||||
|
|
||||||
const (
|
|
||||||
WorkNum = 2048
|
|
||||||
)
|
|
||||||
|
|
||||||
type Dispatcher struct {
|
|
||||||
WorkerPool chan chan *Message
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
InitMessagePool()
|
|
||||||
dispatcher := NewDispatcher()
|
|
||||||
dispatcher.Run()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Dispatcher) Run() {
|
|
||||||
// starting n number of workers
|
|
||||||
for i := 0; i < WorkNum; i++ {
|
|
||||||
worker := NewWorker(d.WorkerPool)
|
|
||||||
worker.Start()
|
|
||||||
}
|
|
||||||
go d.dispatch()
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDispatcher() *Dispatcher {
|
|
||||||
pool := make(chan chan *Message, WorkNum)
|
|
||||||
return &Dispatcher{WorkerPool: pool}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *Dispatcher) dispatch() {
|
|
||||||
for i := 0; i < MessagePoolNum; i++ {
|
|
||||||
go func(idx int) {
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case msg := <-MSGPool[idx].queue:
|
|
||||||
go func(msg *Message) {
|
|
||||||
msgChannel := <-d.WorkerPool
|
|
||||||
msgChannel <- msg
|
|
||||||
}(msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}(i)
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
+57
-57
@@ -1,113 +1,113 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
package broker
|
package broker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"hmq/lib/message"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
simplejson "github.com/bitly/go-simplejson"
|
simplejson "github.com/bitly/go-simplejson"
|
||||||
log "github.com/cihub/seelog"
|
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||||
|
"go.uber.org/zap"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (c *client) SendInfo() {
|
func (c *client) SendInfo() {
|
||||||
|
if c.status == Disconnected {
|
||||||
|
return
|
||||||
|
}
|
||||||
url := c.info.localIP + ":" + c.broker.config.Cluster.Port
|
url := c.info.localIP + ":" + c.broker.config.Cluster.Port
|
||||||
|
|
||||||
infoMsg := NewInfo(c.broker.id, url, false)
|
infoMsg := NewInfo(c.broker.id, url, false)
|
||||||
err := c.writeMessage(infoMsg)
|
err := c.WriterPacket(infoMsg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("send info message error, ", err)
|
log.Error("send info message error, ", zap.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// log.Info("send info success")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) StartPing() {
|
func (c *client) StartPing() {
|
||||||
timeTicker := time.NewTicker(time.Second * 30)
|
timeTicker := time.NewTicker(time.Second * 50)
|
||||||
ping := message.NewPingreqMessage()
|
ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-timeTicker.C:
|
case <-timeTicker.C:
|
||||||
err := c.writeMessage(ping)
|
err := c.WriterPacket(ping)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("ping error: ", err)
|
log.Error("ping error: ", zap.Error(err))
|
||||||
|
c.Close()
|
||||||
}
|
}
|
||||||
|
case <-c.ctx.Done():
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) SendConnect() {
|
func (c *client) SendConnect() {
|
||||||
|
|
||||||
clientID := c.info.clientID
|
if c.status != Connected {
|
||||||
connMsg := message.NewConnectMessage()
|
|
||||||
connMsg.SetClientId(clientID)
|
|
||||||
connMsg.SetVersion(0x04)
|
|
||||||
err := c.writeMessage(connMsg)
|
|
||||||
if err != nil {
|
|
||||||
log.Error("send connect message error, ", err)
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// log.Info("send connet success")
|
m := packets.NewControlPacket(packets.Connect).(*packets.ConnectPacket)
|
||||||
|
|
||||||
|
m.CleanSession = true
|
||||||
|
m.ClientIdentifier = c.info.clientID
|
||||||
|
m.Keepalive = uint16(60)
|
||||||
|
err := c.WriterPacket(m)
|
||||||
|
if err != nil {
|
||||||
|
log.Error("send connect message error, ", zap.Error(err))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Info("send connect success")
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewInfo(sid, url string, isforword bool) *message.PublishMessage {
|
func NewInfo(sid, url string, isforword bool) *packets.PublishPacket {
|
||||||
infoMsg := message.NewPublishMessage()
|
pub := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
|
||||||
infoMsg.SetTopic([]byte(BrokerInfoTopic))
|
pub.Qos = 0
|
||||||
info := fmt.Sprintf(`{"remoteID":"%s","url":"%s","isForward":%t}`, sid, url, isforword)
|
pub.TopicName = BrokerInfoTopic
|
||||||
|
pub.Retain = false
|
||||||
|
info := fmt.Sprintf(`{"brokerID":"%s","brokerUrl":"%s"}`, sid, url)
|
||||||
// log.Info("new info", string(info))
|
// log.Info("new info", string(info))
|
||||||
infoMsg.SetPayload([]byte(info))
|
pub.Payload = []byte(info)
|
||||||
infoMsg.SetQoS(0)
|
return pub
|
||||||
infoMsg.SetRetain(false)
|
|
||||||
return infoMsg
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *client) ProcessInfo(msg *message.PublishMessage) {
|
func (c *client) ProcessInfo(packet *packets.PublishPacket) {
|
||||||
nc := c.conn
|
nc := c.conn
|
||||||
b := c.broker
|
b := c.broker
|
||||||
if nc == nil {
|
if nc == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("recv remoteInfo: ", string(msg.Payload()))
|
log.Info("recv remoteInfo: ", zap.String("payload", string(packet.Payload)))
|
||||||
|
|
||||||
js, e := simplejson.NewJson(msg.Payload())
|
js, err := simplejson.NewJson(packet.Payload)
|
||||||
if e != nil {
|
if err != nil {
|
||||||
log.Warn("parse info message err", e)
|
log.Warn("parse info message err", zap.Error(err))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
rid := js.Get("remoteID").MustString()
|
routes, err := js.Get("data").Map()
|
||||||
rurl := js.Get("url").MustString()
|
if routes == nil {
|
||||||
isForward := js.Get("isForward").MustBool()
|
log.Error("receive info message error, ", zap.Error(err))
|
||||||
|
|
||||||
if rid == "" {
|
|
||||||
log.Error("receive info message error with remoteID is null")
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if rid == b.id {
|
b.nodes = routes
|
||||||
if !isForward {
|
|
||||||
c.Close() //close connet self
|
b.mu.Lock()
|
||||||
|
for rid, rurl := range routes {
|
||||||
|
if rid == b.id {
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
exist := b.CheckRemoteExist(rid, rurl)
|
url, ok := rurl.(string)
|
||||||
if !exist {
|
if ok {
|
||||||
go b.connectRouter(rurl, rid)
|
exist := b.CheckRemoteExist(rid, url)
|
||||||
}
|
if !exist {
|
||||||
// log.Info("isforword: ", isForward)
|
b.connectRouter(rid, url)
|
||||||
if !isForward {
|
}
|
||||||
route := &route{
|
|
||||||
remoteUrl: rurl,
|
|
||||||
remoteID: rid,
|
|
||||||
}
|
}
|
||||||
c.route = route
|
|
||||||
|
|
||||||
go b.SendLocalSubsToRouter(c)
|
|
||||||
// log.Info("BroadcastInfoMessage starting... ")
|
|
||||||
infoMsg := NewInfo(rid, rurl, true)
|
|
||||||
b.BroadcastInfoMessage(rid, infoMsg)
|
|
||||||
}
|
}
|
||||||
|
b.mu.Unlock()
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,58 +0,0 @@
|
|||||||
package broker
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
const (
|
|
||||||
MaxUser = 1024 * 1024
|
|
||||||
MessagePoolNum = 1024
|
|
||||||
MessagePoolUser = MaxUser / MessagePoolNum
|
|
||||||
MessagePoolMessageNum = MaxUser / MessagePoolNum * 4
|
|
||||||
)
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
client *client
|
|
||||||
msg []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
MSGPool []MessagePool
|
|
||||||
)
|
|
||||||
|
|
||||||
type MessagePool struct {
|
|
||||||
l sync.Mutex
|
|
||||||
maxuser int
|
|
||||||
user int
|
|
||||||
queue chan *Message
|
|
||||||
}
|
|
||||||
|
|
||||||
func InitMessagePool() {
|
|
||||||
MSGPool = make([]MessagePool, (MessagePoolNum + 2))
|
|
||||||
for i := 0; i < (MessagePoolNum + 2); i++ {
|
|
||||||
MSGPool[i].Init(MessagePoolUser, MessagePoolMessageNum)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MessagePool) Init(num int, maxusernum int) {
|
|
||||||
p.maxuser = maxusernum
|
|
||||||
p.queue = make(chan *Message, num)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MessagePool) GetPool() *MessagePool {
|
|
||||||
p.l.Lock()
|
|
||||||
if p.user+1 < p.maxuser {
|
|
||||||
p.user += 1
|
|
||||||
p.l.Unlock()
|
|
||||||
return p
|
|
||||||
} else {
|
|
||||||
p.l.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *MessagePool) Reduce() {
|
|
||||||
p.l.Lock()
|
|
||||||
p.user -= 1
|
|
||||||
p.l.Unlock()
|
|
||||||
|
|
||||||
}
|
|
||||||
@@ -1,236 +0,0 @@
|
|||||||
package broker
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"hmq/lib/message"
|
|
||||||
"io"
|
|
||||||
"net"
|
|
||||||
|
|
||||||
log "github.com/cihub/seelog"
|
|
||||||
)
|
|
||||||
|
|
||||||
func checkError(desc string, err error) {
|
|
||||||
if err != nil {
|
|
||||||
log.Error(desc, " : ", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func ReadPacket(conn net.Conn) ([]byte, error) {
|
|
||||||
if conn == nil {
|
|
||||||
return nil, errors.New("conn is null")
|
|
||||||
}
|
|
||||||
// conn.SetReadDeadline(t)
|
|
||||||
var buf []byte
|
|
||||||
// read fix header
|
|
||||||
b := make([]byte, 1)
|
|
||||||
_, err := io.ReadFull(conn, b)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
buf = append(buf, b...)
|
|
||||||
// read rem msg length
|
|
||||||
rembuf, remlen := decodeLength(conn)
|
|
||||||
buf = append(buf, rembuf...)
|
|
||||||
// read rem msg
|
|
||||||
packetBytes := make([]byte, remlen)
|
|
||||||
_, err = io.ReadFull(conn, packetBytes)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
buf = append(buf, packetBytes...)
|
|
||||||
// log.Info("len buf: ", len(buf))
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func decodeLength(r io.Reader) ([]byte, int) {
|
|
||||||
var rLength uint32
|
|
||||||
var multiplier uint32
|
|
||||||
var buf []byte
|
|
||||||
b := make([]byte, 1)
|
|
||||||
for {
|
|
||||||
io.ReadFull(r, b)
|
|
||||||
digit := b[0]
|
|
||||||
buf = append(buf, b[0])
|
|
||||||
rLength |= uint32(digit&127) << multiplier
|
|
||||||
if (digit & 128) == 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
multiplier += 7
|
|
||||||
|
|
||||||
}
|
|
||||||
return buf, int(rLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeMessage(buf []byte) (message.Message, error) {
|
|
||||||
msgType := uint8(buf[0] & 0xF0 >> 4)
|
|
||||||
switch msgType {
|
|
||||||
case CONNECT:
|
|
||||||
return DecodeConnectMessage(buf)
|
|
||||||
case CONNACK:
|
|
||||||
return DecodeConnackMessage(buf)
|
|
||||||
case PUBLISH:
|
|
||||||
return DecodePublishMessage(buf)
|
|
||||||
case PUBACK:
|
|
||||||
return DecodePubackMessage(buf)
|
|
||||||
case PUBCOMP:
|
|
||||||
return DecodePubcompMessage(buf)
|
|
||||||
case PUBREC:
|
|
||||||
return DecodePubrecMessage(buf)
|
|
||||||
case PUBREL:
|
|
||||||
return DecodePubrelMessage(buf)
|
|
||||||
case SUBSCRIBE:
|
|
||||||
return DecodeSubscribeMessage(buf)
|
|
||||||
case SUBACK:
|
|
||||||
return DecodeSubackMessage(buf)
|
|
||||||
case UNSUBSCRIBE:
|
|
||||||
return DecodeUnsubscribeMessage(buf)
|
|
||||||
case UNSUBACK:
|
|
||||||
return DecodeUnsubackMessage(buf)
|
|
||||||
case PINGREQ:
|
|
||||||
return DecodePingreqMessage(buf)
|
|
||||||
case PINGRESP:
|
|
||||||
return DecodePingrespMessage(buf)
|
|
||||||
case DISCONNECT:
|
|
||||||
return DecodeDisconnectMessage(buf)
|
|
||||||
default:
|
|
||||||
return nil, errors.New("error message type")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeConnectMessage(buf []byte) (*message.ConnectMessage, error) {
|
|
||||||
connMsg := message.NewConnectMessage()
|
|
||||||
_, err := connMsg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
if !message.ValidConnackError(err) {
|
|
||||||
return nil, errors.New("Connect message format error, " + err.Error())
|
|
||||||
}
|
|
||||||
return nil, errors.New("Deode connect message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return connMsg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeConnackMessage(buf []byte) (*message.ConnackMessage, error) {
|
|
||||||
msg := message.NewConnackMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Connack message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodePublishMessage(buf []byte) (*message.PublishMessage, error) {
|
|
||||||
msg := message.NewPublishMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Publish message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodePubackMessage(buf []byte) (*message.PubackMessage, error) {
|
|
||||||
msg := message.NewPubackMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Puback message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodePubrecMessage(buf []byte) (*message.PubrecMessage, error) {
|
|
||||||
msg := message.NewPubrecMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Pubrec message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodePubrelMessage(buf []byte) (*message.PubrelMessage, error) {
|
|
||||||
msg := message.NewPubrelMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Pubrel message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodePubcompMessage(buf []byte) (*message.PubcompMessage, error) {
|
|
||||||
msg := message.NewPubcompMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Pubcomp message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeSubscribeMessage(buf []byte) (*message.SubscribeMessage, error) {
|
|
||||||
msg := message.NewSubscribeMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Subscribe message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeSubackMessage(buf []byte) (*message.SubackMessage, error) {
|
|
||||||
msg := message.NewSubackMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Suback message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeUnsubscribeMessage(buf []byte) (*message.UnsubscribeMessage, error) {
|
|
||||||
msg := message.NewUnsubscribeMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Unsubscribe message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeUnsubackMessage(buf []byte) (*message.UnsubackMessage, error) {
|
|
||||||
msg := message.NewUnsubackMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Unsuback message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodePingreqMessage(buf []byte) (*message.PingreqMessage, error) {
|
|
||||||
msg := message.NewPingreqMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Pingreq message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodePingrespMessage(buf []byte) (*message.PingrespMessage, error) {
|
|
||||||
msg := message.NewPingrespMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Pingresp message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func DecodeDisconnectMessage(buf []byte) (*message.DisconnectMessage, error) {
|
|
||||||
msg := message.NewDisconnectMessage()
|
|
||||||
_, err := msg.Decode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.New("Decode Disconnect message error, " + err.Error())
|
|
||||||
}
|
|
||||||
return msg, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func EncodeMessage(msg message.Message) ([]byte, error) {
|
|
||||||
buf := make([]byte, msg.Len())
|
|
||||||
_, err := msg.Encode(buf)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return buf, nil
|
|
||||||
}
|
|
||||||
@@ -1,120 +0,0 @@
|
|||||||
package broker
|
|
||||||
|
|
||||||
import (
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
type RetainList struct {
|
|
||||||
sync.RWMutex
|
|
||||||
root *rlevel
|
|
||||||
}
|
|
||||||
type rlevel struct {
|
|
||||||
nodes map[string]*rnode
|
|
||||||
}
|
|
||||||
type rnode struct {
|
|
||||||
next *rlevel
|
|
||||||
msg []byte
|
|
||||||
}
|
|
||||||
type RetainResult struct {
|
|
||||||
msg [][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRNode() *rnode {
|
|
||||||
return &rnode{msg: make([]byte, 0, 4)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRLevel() *rlevel {
|
|
||||||
return &rlevel{nodes: make(map[string]*rnode)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRetainList() *RetainList {
|
|
||||||
return &RetainList{root: newRLevel()}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RetainList) Insert(topic, buf []byte) error {
|
|
||||||
|
|
||||||
tokens, err := PublishTopicCheckAndSpilt(topic)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
// log.Info("insert tokens:", tokens)
|
|
||||||
r.Lock()
|
|
||||||
|
|
||||||
l := r.root
|
|
||||||
var n *rnode
|
|
||||||
for _, t := range tokens {
|
|
||||||
n = l.nodes[t]
|
|
||||||
if n == nil {
|
|
||||||
n = newRNode()
|
|
||||||
l.nodes[t] = n
|
|
||||||
}
|
|
||||||
if n.next == nil {
|
|
||||||
n.next = newRLevel()
|
|
||||||
}
|
|
||||||
l = n.next
|
|
||||||
}
|
|
||||||
n.msg = buf
|
|
||||||
r.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *RetainList) Match(topic []byte) [][]byte {
|
|
||||||
|
|
||||||
tokens, err := SubscribeTopicCheckAndSpilt(topic)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
results := &RetainResult{}
|
|
||||||
|
|
||||||
r.Lock()
|
|
||||||
l := r.root
|
|
||||||
matchRLevel(l, tokens, results)
|
|
||||||
r.Unlock()
|
|
||||||
// log.Info("results: ", results)
|
|
||||||
return results.msg
|
|
||||||
|
|
||||||
}
|
|
||||||
func matchRLevel(l *rlevel, toks []string, results *RetainResult) {
|
|
||||||
var n *rnode
|
|
||||||
for i, t := range toks {
|
|
||||||
if l == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// log.Info("l info :", l.nodes)
|
|
||||||
if t == "#" {
|
|
||||||
for _, n := range l.nodes {
|
|
||||||
n.GetAll(results)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if t == "+" {
|
|
||||||
for _, n := range l.nodes {
|
|
||||||
if len(t[i+1:]) == 0 {
|
|
||||||
results.msg = append(results.msg, n.msg)
|
|
||||||
} else {
|
|
||||||
matchRLevel(n.next, toks[i+1:], results)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
n = l.nodes[t]
|
|
||||||
if n != nil {
|
|
||||||
l = n.next
|
|
||||||
} else {
|
|
||||||
l = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if n != nil {
|
|
||||||
results.msg = append(results.msg, n.msg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *rnode) GetAll(results *RetainResult) {
|
|
||||||
// log.Info("node 's message: ", string(r.msg))
|
|
||||||
if r.msg != nil && string(r.msg) != "" {
|
|
||||||
results.msg = append(results.msg, r.msg)
|
|
||||||
}
|
|
||||||
l := r.next
|
|
||||||
for _, n := range l.nodes {
|
|
||||||
n.GetAll(results)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,55 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
|
package broker
|
||||||
|
|
||||||
|
import "github.com/eclipse/paho.mqtt.golang/packets"
|
||||||
|
|
||||||
|
func (b *Broker) getSession(cli *client, req *packets.ConnectPacket, resp *packets.ConnackPacket) error {
|
||||||
|
// If CleanSession is set to 0, the server MUST resume communications with the
|
||||||
|
// client based on state from the current session, as identified by the client
|
||||||
|
// identifier. If there is no session associated with the client identifier the
|
||||||
|
// server must create a new session.
|
||||||
|
//
|
||||||
|
// If CleanSession is set to 1, the client and server must discard any previous
|
||||||
|
// session and start a new one. b session lasts as long as the network c
|
||||||
|
// onnection. State data associated with b session must not be reused in any
|
||||||
|
// subsequent session.
|
||||||
|
|
||||||
|
var err error
|
||||||
|
|
||||||
|
// Check to see if the client supplied an ID, if not, generate one and set
|
||||||
|
// clean session.
|
||||||
|
|
||||||
|
if len(req.ClientIdentifier) == 0 {
|
||||||
|
req.CleanSession = true
|
||||||
|
}
|
||||||
|
|
||||||
|
cid := req.ClientIdentifier
|
||||||
|
|
||||||
|
// If CleanSession is NOT set, check the session store for existing session.
|
||||||
|
// If found, return it.
|
||||||
|
if !req.CleanSession {
|
||||||
|
if cli.session, err = b.sessionMgr.Get(cid); err == nil {
|
||||||
|
resp.SessionPresent = true
|
||||||
|
|
||||||
|
if err := cli.session.Update(req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If CleanSession, or no existing session found, then create a new one
|
||||||
|
if cli.session == nil {
|
||||||
|
if cli.session, err = b.sessionMgr.New(cid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
resp.SessionPresent = false
|
||||||
|
|
||||||
|
if err := cli.session.Init(req); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,318 +0,0 @@
|
|||||||
package broker
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
log "github.com/cihub/seelog"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A result structure better optimized for queue subs.
|
|
||||||
type SublistResult struct {
|
|
||||||
psubs []*subscription
|
|
||||||
qsubs []*subscription // don't make this a map, too expensive to iterate
|
|
||||||
}
|
|
||||||
|
|
||||||
// A Sublist stores and efficiently retrieves subscriptions.
|
|
||||||
type Sublist struct {
|
|
||||||
sync.RWMutex
|
|
||||||
cache map[string]*SublistResult
|
|
||||||
root *level
|
|
||||||
}
|
|
||||||
|
|
||||||
// A node contains subscriptions and a pointer to the next level.
|
|
||||||
type node struct {
|
|
||||||
next *level
|
|
||||||
psubs []*subscription
|
|
||||||
qsubs []*subscription
|
|
||||||
}
|
|
||||||
|
|
||||||
// A level represents a group of nodes and special pointers to
|
|
||||||
// wildcard nodes.
|
|
||||||
type level struct {
|
|
||||||
nodes map[string]*node
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new default node.
|
|
||||||
func newNode() *node {
|
|
||||||
return &node{psubs: make([]*subscription, 0, 4), qsubs: make([]*subscription, 0, 4)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Create a new default level. We use FNV1A as the hash
|
|
||||||
// algortihm for the tokens, which should be short.
|
|
||||||
func newLevel() *level {
|
|
||||||
return &level{nodes: make(map[string]*node)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// New will create a default sublist
|
|
||||||
func NewSublist() *Sublist {
|
|
||||||
return &Sublist{root: newLevel(), cache: make(map[string]*SublistResult)}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert adds a subscription into the sublist
|
|
||||||
func (s *Sublist) Insert(sub *subscription) error {
|
|
||||||
|
|
||||||
tokens, err := SubscribeTopicCheckAndSpilt(sub.topic)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.Lock()
|
|
||||||
|
|
||||||
l := s.root
|
|
||||||
var n *node
|
|
||||||
for _, t := range tokens {
|
|
||||||
n = l.nodes[t]
|
|
||||||
if n == nil {
|
|
||||||
n = newNode()
|
|
||||||
l.nodes[t] = n
|
|
||||||
}
|
|
||||||
if n.next == nil {
|
|
||||||
n.next = newLevel()
|
|
||||||
}
|
|
||||||
l = n.next
|
|
||||||
}
|
|
||||||
if sub.queue {
|
|
||||||
//check qsub is already exist
|
|
||||||
for i := range n.qsubs {
|
|
||||||
if equal(n.qsubs[i], sub) {
|
|
||||||
n.qsubs[i] = sub
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
n.qsubs = append(n.qsubs, sub)
|
|
||||||
} else {
|
|
||||||
//check psub is already exist
|
|
||||||
for i := range n.psubs {
|
|
||||||
if equal(n.psubs[i], sub) {
|
|
||||||
n.psubs[i] = sub
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
n.psubs = append(n.psubs, sub)
|
|
||||||
}
|
|
||||||
|
|
||||||
topic := string(sub.topic)
|
|
||||||
s.addToCache(topic, sub)
|
|
||||||
s.Unlock()
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sublist) addToCache(topic string, sub *subscription) {
|
|
||||||
for k, r := range s.cache {
|
|
||||||
if matchLiteral(k, topic) {
|
|
||||||
// Copy since others may have a reference.
|
|
||||||
nr := copyResult(r)
|
|
||||||
if sub.queue == false {
|
|
||||||
nr.psubs = append(nr.psubs, sub)
|
|
||||||
} else {
|
|
||||||
nr.qsubs = append(nr.qsubs, sub)
|
|
||||||
}
|
|
||||||
s.cache[k] = nr
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sublist) removeFromCache(topic string, sub *subscription) {
|
|
||||||
for k := range s.cache {
|
|
||||||
if !matchLiteral(k, topic) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
// Since someone else may be referecing, can't modify the list
|
|
||||||
// safely, just let it re-populate.
|
|
||||||
delete(s.cache, k)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func matchLiteral(literal, topic string) bool {
|
|
||||||
tok, _ := SubscribeTopicCheckAndSpilt([]byte(topic))
|
|
||||||
li, _ := PublishTopicCheckAndSpilt([]byte(literal))
|
|
||||||
|
|
||||||
for i := 0; i < len(tok); i++ {
|
|
||||||
b := tok[i]
|
|
||||||
switch b {
|
|
||||||
case "+":
|
|
||||||
|
|
||||||
case "#":
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
if b != li[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Deep copy
|
|
||||||
func copyResult(r *SublistResult) *SublistResult {
|
|
||||||
nr := &SublistResult{}
|
|
||||||
nr.psubs = append([]*subscription(nil), r.psubs...)
|
|
||||||
nr.qsubs = append([]*subscription(nil), r.qsubs...)
|
|
||||||
return nr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sublist) Remove(sub *subscription) error {
|
|
||||||
tokens, err := SubscribeTopicCheckAndSpilt(sub.topic)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
s.Lock()
|
|
||||||
defer s.Unlock()
|
|
||||||
|
|
||||||
l := s.root
|
|
||||||
var n *node
|
|
||||||
|
|
||||||
for _, t := range tokens {
|
|
||||||
if l == nil {
|
|
||||||
return errors.New("No Matches subscription Found")
|
|
||||||
}
|
|
||||||
n = l.nodes[t]
|
|
||||||
if n != nil {
|
|
||||||
l = n.next
|
|
||||||
} else {
|
|
||||||
l = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !s.removeFromNode(n, sub) {
|
|
||||||
return errors.New("No Matches subscription Found")
|
|
||||||
}
|
|
||||||
topic := string(sub.topic)
|
|
||||||
s.removeFromCache(topic, sub)
|
|
||||||
return nil
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sublist) removeFromNode(n *node, sub *subscription) (found bool) {
|
|
||||||
if n == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if sub.queue {
|
|
||||||
n.qsubs, found = removeSubFromList(sub, n.qsubs)
|
|
||||||
return found
|
|
||||||
} else {
|
|
||||||
n.psubs, found = removeSubFromList(sub, n.psubs)
|
|
||||||
return found
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Sublist) Match(topic string) *SublistResult {
|
|
||||||
s.RLock()
|
|
||||||
rc, ok := s.cache[topic]
|
|
||||||
s.RUnlock()
|
|
||||||
|
|
||||||
if ok {
|
|
||||||
return rc
|
|
||||||
}
|
|
||||||
|
|
||||||
tokens, err := PublishTopicCheckAndSpilt([]byte(topic))
|
|
||||||
if err != nil {
|
|
||||||
log.Error("\tserver/sublist.go: ", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
result := &SublistResult{}
|
|
||||||
|
|
||||||
s.Lock()
|
|
||||||
l := s.root
|
|
||||||
if len(tokens) > 0 {
|
|
||||||
if tokens[0] == "/" {
|
|
||||||
if _, exist := l.nodes["#"]; exist {
|
|
||||||
addNodeToResults(l.nodes["#"], result)
|
|
||||||
}
|
|
||||||
if _, exist := l.nodes["+"]; exist {
|
|
||||||
matchLevel(l.nodes["/"].next, tokens[1:], result)
|
|
||||||
}
|
|
||||||
if _, exist := l.nodes["/"]; exist {
|
|
||||||
matchLevel(l.nodes["/"].next, tokens[1:], result)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
matchLevel(s.root, tokens, result)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.cache[topic] = result
|
|
||||||
if len(s.cache) > 1024 {
|
|
||||||
for k := range s.cache {
|
|
||||||
delete(s.cache, k)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.Unlock()
|
|
||||||
// log.Info("SublistResult: ", result)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
func matchLevel(l *level, toks []string, results *SublistResult) {
|
|
||||||
var swc, n *node
|
|
||||||
exist := false
|
|
||||||
for i, t := range toks {
|
|
||||||
if l == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, exist = l.nodes["#"]; exist {
|
|
||||||
addNodeToResults(l.nodes["#"], results)
|
|
||||||
}
|
|
||||||
if t != "/" {
|
|
||||||
if swc, exist = l.nodes["+"]; exist {
|
|
||||||
matchLevel(l.nodes["+"].next, toks[i+1:], results)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if _, exist = l.nodes["+"]; exist {
|
|
||||||
addNodeToResults(l.nodes["+"], results)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
n = l.nodes[t]
|
|
||||||
if n != nil {
|
|
||||||
l = n.next
|
|
||||||
} else {
|
|
||||||
l = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if n != nil {
|
|
||||||
addNodeToResults(n, results)
|
|
||||||
}
|
|
||||||
if swc != nil {
|
|
||||||
addNodeToResults(n, results)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// This will add in a node's results to the total results.
|
|
||||||
func addNodeToResults(n *node, results *SublistResult) {
|
|
||||||
results.psubs = append(results.psubs, n.psubs...)
|
|
||||||
results.qsubs = append(results.qsubs, n.qsubs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func removeSubFromList(sub *subscription, sl []*subscription) ([]*subscription, bool) {
|
|
||||||
for i := 0; i < len(sl); i++ {
|
|
||||||
if sl[i] == sub {
|
|
||||||
last := len(sl) - 1
|
|
||||||
sl[i] = sl[last]
|
|
||||||
sl[last] = nil
|
|
||||||
sl = sl[:last]
|
|
||||||
// log.Info("removeSubFromList success")
|
|
||||||
return shrinkAsNeeded(sl), true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sl, false
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks if we need to do a resize. This is for very large growth then
|
|
||||||
// subsequent return to a more normal size from unsubscribe.
|
|
||||||
func shrinkAsNeeded(sl []*subscription) []*subscription {
|
|
||||||
lsl := len(sl)
|
|
||||||
csl := cap(sl)
|
|
||||||
// Don't bother if list not too big
|
|
||||||
if csl <= 8 {
|
|
||||||
return sl
|
|
||||||
}
|
|
||||||
pFree := float32(csl-lsl) / float32(csl)
|
|
||||||
if pFree > 0.50 {
|
|
||||||
return append([]*subscription(nil), sl...)
|
|
||||||
}
|
|
||||||
return sl
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,24 @@
|
|||||||
|
package broker
|
||||||
|
|
||||||
|
var usageStr = `
|
||||||
|
Usage: hmq [options]
|
||||||
|
|
||||||
|
Broker Options:
|
||||||
|
-w, --worker <number> Worker num to process message, perfer (client num)/10. (default 1024)
|
||||||
|
-p, --port <port> Use port for clients (default: 1883)
|
||||||
|
--host <host> Network host to listen on. (default "0.0.0.0")
|
||||||
|
-ws, --wsport <port> Use port for websocket monitoring
|
||||||
|
-wsp,--wspath <path> Use path for websocket monitoring
|
||||||
|
-c, --config <file> Configuration file
|
||||||
|
|
||||||
|
Logging Options:
|
||||||
|
-d, --debug <bool> Enable debugging output (default false)
|
||||||
|
-D Debug and trace
|
||||||
|
|
||||||
|
Cluster Options:
|
||||||
|
-r, --router <rurl> Router who maintenance cluster info
|
||||||
|
-cp, --clusterport <cluster-port> Cluster listen port for others
|
||||||
|
|
||||||
|
Common Options:
|
||||||
|
-h, --help Show this message
|
||||||
|
`
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
package broker
|
|
||||||
|
|
||||||
type Worker struct {
|
|
||||||
WorkerPool chan chan *Message
|
|
||||||
MsgChannel chan *Message
|
|
||||||
quit chan bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewWorker(workerPool chan chan *Message) Worker {
|
|
||||||
return Worker{
|
|
||||||
WorkerPool: workerPool,
|
|
||||||
MsgChannel: make(chan *Message),
|
|
||||||
quit: make(chan bool)}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w Worker) Start() {
|
|
||||||
go func() {
|
|
||||||
for {
|
|
||||||
// register the current worker into the worker queue.
|
|
||||||
w.WorkerPool <- w.MsgChannel
|
|
||||||
select {
|
|
||||||
case msg := <-w.MsgChannel:
|
|
||||||
// we have received a work request.
|
|
||||||
ProcessMessage(msg)
|
|
||||||
case <-w.quit:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop signals the worker to stop listening for work requests.
|
|
||||||
func (w Worker) Stop() {
|
|
||||||
go func() {
|
|
||||||
w.quit <- true
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
@@ -1,11 +1,12 @@
|
|||||||
{
|
{
|
||||||
|
"workerNum": 4096,
|
||||||
"port": "1883",
|
"port": "1883",
|
||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"cluster": {
|
"cluster": {
|
||||||
"host": "0.0.0.0",
|
"host": "0.0.0.0",
|
||||||
"port": "1993",
|
"port": "1993"
|
||||||
"routes": []
|
|
||||||
},
|
},
|
||||||
|
"router": "127.0.0.1:9888",
|
||||||
"tlsPort": "8883",
|
"tlsPort": "8883",
|
||||||
"tlsHost": "0.0.0.0",
|
"tlsHost": "0.0.0.0",
|
||||||
"wsPort": "1888",
|
"wsPort": "1888",
|
||||||
@@ -17,6 +18,6 @@
|
|||||||
"certFile": "ssl/server/cert.pem",
|
"certFile": "ssl/server/cert.pem",
|
||||||
"keyFile": "ssl/server/key.pem"
|
"keyFile": "ssl/server/key.pem"
|
||||||
},
|
},
|
||||||
"acl": true,
|
"acl": false,
|
||||||
"aclConf": "conf/acl.conf"
|
"aclConf": "conf/acl.conf"
|
||||||
}
|
}
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
module github.com/fhmq/hmq
|
||||||
|
|
||||||
|
go 1.12
|
||||||
|
|
||||||
|
require (
|
||||||
|
github.com/StackExchange/wmi v0.0.0-20181212234831-e0a55b97c705 // indirect
|
||||||
|
github.com/bitly/go-simplejson v0.5.0
|
||||||
|
github.com/eclipse/paho.mqtt.golang v1.2.0
|
||||||
|
github.com/fsnotify/fsnotify v1.4.7
|
||||||
|
github.com/go-ole/go-ole v1.2.4 // indirect
|
||||||
|
github.com/segmentio/fasthash v0.0.0-20180216231524-a72b379d632e
|
||||||
|
github.com/shirou/gopsutil v2.18.12+incompatible
|
||||||
|
github.com/stretchr/testify v1.3.0
|
||||||
|
go.uber.org/atomic v1.3.2 // indirect
|
||||||
|
go.uber.org/multierr v1.1.0 // indirect
|
||||||
|
go.uber.org/zap v1.9.1
|
||||||
|
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd // indirect
|
||||||
|
golang.org/x/lint v0.0.0-20190409202823-959b441ac422 // indirect
|
||||||
|
golang.org/x/net v0.0.0-20190424024845-afe8014c977f
|
||||||
|
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 // indirect
|
||||||
|
golang.org/x/tools v0.0.0-20190424031103-cb2dda6eabdf // indirect
|
||||||
|
)
|
||||||
@@ -0,0 +1,45 @@
|
|||||||
|
github.com/StackExchange/wmi v0.0.0-20181212234831-e0a55b97c705 h1:UUppSQnhf4Yc6xGxSkoQpPhb7RVzuv5Nb1mwJ5VId9s=
|
||||||
|
github.com/StackExchange/wmi v0.0.0-20181212234831-e0a55b97c705/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg=
|
||||||
|
github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
|
||||||
|
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
|
||||||
|
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||||
|
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||||
|
github.com/eclipse/paho.mqtt.golang v1.2.0 h1:1F8mhG9+aO5/xpdtFkW4SxOJB67ukuDC3t2y2qayIX0=
|
||||||
|
github.com/eclipse/paho.mqtt.golang v1.2.0/go.mod h1:H9keYFcgq3Qr5OUJm/JZI/i6U7joQ8SYLhZwfeOo6Ts=
|
||||||
|
github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I=
|
||||||
|
github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo=
|
||||||
|
github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI=
|
||||||
|
github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||||
|
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||||
|
github.com/segmentio/fasthash v0.0.0-20180216231524-a72b379d632e h1:uO75wNGioszjmIzcY/tvdDYKRLVvzggtAmmJkn9j4GQ=
|
||||||
|
github.com/segmentio/fasthash v0.0.0-20180216231524-a72b379d632e/go.mod h1:tm/wZFQ8e24NYaBGIlnO2WGCAi67re4HHuOm0sftE/M=
|
||||||
|
github.com/shirou/gopsutil v2.18.12+incompatible h1:1eaJvGomDnH74/5cF4CTmTbLHAriGFsTZppLXDX93OM=
|
||||||
|
github.com/shirou/gopsutil v2.18.12+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||||
|
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||||
|
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||||
|
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||||
|
go.uber.org/atomic v1.3.2 h1:2Oa65PReHzfn29GpvgsYwloV9AVFHPDk8tYxt2c2tr4=
|
||||||
|
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||||
|
go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
|
||||||
|
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
||||||
|
go.uber.org/zap v1.9.1 h1:XCJQEf3W6eZaVwhRBof6ImoYGJSITeKWsyeh3HFu/5o=
|
||||||
|
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||||
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
|
golang.org/x/crypto v0.0.0-20190422183909-d864b10871cd/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
|
golang.org/x/lint v0.0.0-20190409202823-959b441ac422 h1:QzoH/1pFpZguR8NrRHLcO6jKqfv2zpuSqZLgdm7ZmjI=
|
||||||
|
golang.org/x/lint v0.0.0-20190409202823-959b441ac422/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||||
|
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/net v0.0.0-20190424024845-afe8014c977f h1:uALRiwYevCJtciRa4mKKFkrs5jY4F2OTf1D2sfi1swY=
|
||||||
|
golang.org/x/net v0.0.0-20190424024845-afe8014c977f/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
|
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
|
||||||
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/sys v0.0.0-20190422165155-953cdadca894 h1:Cz4ceDQGXuKRnVBDTS23GTn/pU5OE2C0WrNTOYK1Uuc=
|
||||||
|
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
|
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||||
|
golang.org/x/tools v0.0.0-20190424031103-cb2dda6eabdf h1:Yv3pKbXQqpdhrt53r+Yr1XveoqVgIFTCQdaamSalWwM=
|
||||||
|
golang.org/x/tools v0.0.0-20190424031103-cb2dda6eabdf/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||||
@@ -1,3 +1,5 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>*/
|
||||||
package acl
|
package acl
|
||||||
|
|
||||||
import "strings"
|
import "strings"
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
|||||||
@@ -1,16 +0,0 @@
|
|||||||
# This is the official list of SurgeMQ authors for copyright purposes.
|
|
||||||
|
|
||||||
# If you are submitting a patch, please add your name or the name of the
|
|
||||||
# organization which holds the copyright to this list in alphabetical order.
|
|
||||||
|
|
||||||
# Names should be added to this file as
|
|
||||||
# Name <email address>
|
|
||||||
# The email address is not required for organizations.
|
|
||||||
# Please keep the list sorted.
|
|
||||||
|
|
||||||
|
|
||||||
# Individual Persons
|
|
||||||
|
|
||||||
Jian Zhen <zhenjl@gmail.com>
|
|
||||||
|
|
||||||
# Organizations
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
# Contributing Guidelines
|
|
||||||
|
|
||||||
## Reporting Issues
|
|
||||||
|
|
||||||
Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/surgemq/message/issues?state=open) or was [recently closed](https://github.com/surgemq/message/issues?direction=desc&page=1&sort=updated&state=closed).
|
|
||||||
|
|
||||||
Please provide the following minimum information:
|
|
||||||
* Your SurgeMQ version (or git SHA)
|
|
||||||
* Your Go version (run `go version` in your console)
|
|
||||||
* A detailed issue description
|
|
||||||
* Error Log if present
|
|
||||||
* If possible, a short example
|
|
||||||
|
|
||||||
|
|
||||||
## Contributing Code
|
|
||||||
|
|
||||||
By contributing to this project, you share your code under the Apache License, Version 2.0, as specified in the LICENSE file.
|
|
||||||
Don't forget to add yourself to the AUTHORS file.
|
|
||||||
|
|
||||||
### Pull Requests Checklist
|
|
||||||
|
|
||||||
Please check the following points before submitting your pull request:
|
|
||||||
- [x] Code compiles correctly
|
|
||||||
- [x] Created tests, if possible
|
|
||||||
- [x] All tests pass
|
|
||||||
- [x] Extended the README / documentation, if necessary
|
|
||||||
- [x] Added yourself to the AUTHORS file
|
|
||||||
|
|
||||||
### Code Review
|
|
||||||
|
|
||||||
Everyone is invited to review and comment on pull requests.
|
|
||||||
If it looks fine to you, comment with "LGTM" (Looks good to me).
|
|
||||||
|
|
||||||
If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes.
|
|
||||||
|
|
||||||
Before merging the Pull Request, at least one [team member](https://github.com/orgs/surgemq/people) must have commented with "LGTM".
|
|
||||||
@@ -1,136 +0,0 @@
|
|||||||
Package message is an encoder/decoder library for MQTT 3.1 and 3.1.1 messages. You can
|
|
||||||
find the MQTT specs at the following locations:
|
|
||||||
|
|
||||||
> 3.1.1 - http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/
|
|
||||||
> 3.1 - http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html
|
|
||||||
|
|
||||||
From the spec:
|
|
||||||
|
|
||||||
> MQTT is a Client Server publish/subscribe messaging transport protocol. It is
|
|
||||||
> light weight, open, simple, and designed so as to be easy to implement. These
|
|
||||||
> characteristics make it ideal for use in many situations, including constrained
|
|
||||||
> environments such as for communication in Machine to Machine (M2M) and Internet
|
|
||||||
> of Things (IoT) contexts where a small code footprint is required and/or network
|
|
||||||
> bandwidth is at a premium.
|
|
||||||
>
|
|
||||||
> The MQTT protocol works by exchanging a series of MQTT messages in a defined way.
|
|
||||||
> The protocol runs over TCP/IP, or over other network protocols that provide
|
|
||||||
> ordered, lossless, bi-directional connections.
|
|
||||||
|
|
||||||
|
|
||||||
There are two main items to take note in this package. The first is
|
|
||||||
|
|
||||||
```
|
|
||||||
type MessageType byte
|
|
||||||
```
|
|
||||||
|
|
||||||
MessageType is the type representing the MQTT packet types. In the MQTT spec, MQTT
|
|
||||||
control packet type is represented as a 4-bit unsigned value. MessageType receives
|
|
||||||
several methods that returns string representations of the names and descriptions.
|
|
||||||
|
|
||||||
Also, one of the methods is New(). It returns a new Message object based on the mtype
|
|
||||||
parameter. For example:
|
|
||||||
|
|
||||||
```
|
|
||||||
m, err := CONNECT.New()
|
|
||||||
msg := m.(*ConnectMessage)
|
|
||||||
```
|
|
||||||
|
|
||||||
This would return a PublishMessage struct, but mapped to the Message interface. You can
|
|
||||||
then type assert it back to a *PublishMessage. Another way to create a new
|
|
||||||
PublishMessage is to call
|
|
||||||
|
|
||||||
```
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
```
|
|
||||||
|
|
||||||
Every message type has a New function that returns a new message. The list of available
|
|
||||||
message types are defined as constants below.
|
|
||||||
|
|
||||||
As you may have noticed, the second important item is the Message interface. It defines
|
|
||||||
several methods that are common to all messages, including Name(), Desc(), and Type().
|
|
||||||
Most importantly, it also defines the Encode() and Decode() methods.
|
|
||||||
|
|
||||||
```
|
|
||||||
Encode() (io.Reader, int, error)
|
|
||||||
Decode(io.Reader) (int, error)
|
|
||||||
```
|
|
||||||
|
|
||||||
Encode returns an io.Reader in which the encoded bytes can be read. The second return
|
|
||||||
value is the number of bytes encoded, so the caller knows how many bytes there will be.
|
|
||||||
If Encode returns an error, then the first two return values should be considered invalid.
|
|
||||||
Any changes to the message after Encode() is called will invalidate the io.Reader.
|
|
||||||
|
|
||||||
Decode reads from the io.Reader parameter until a full message is decoded, or when io.Reader
|
|
||||||
returns EOF or error. The first return value is the number of bytes read from io.Reader.
|
|
||||||
The second is error if Decode encounters any problems.
|
|
||||||
|
|
||||||
With these in mind, we can now do:
|
|
||||||
|
|
||||||
```
|
|
||||||
// Create a new CONNECT message
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
|
|
||||||
// Set the appropriate parameters
|
|
||||||
msg.SetWillQos(1)
|
|
||||||
msg.SetVersion(4)
|
|
||||||
msg.SetCleanSession(true)
|
|
||||||
msg.SetClientId([]byte("surgemq"))
|
|
||||||
msg.SetKeepAlive(10)
|
|
||||||
msg.SetWillTopic([]byte("will"))
|
|
||||||
msg.SetWillMessage([]byte("send me home"))
|
|
||||||
msg.SetUsername([]byte("surgemq"))
|
|
||||||
msg.SetPassword([]byte("verysecret"))
|
|
||||||
|
|
||||||
// Encode the message and get the io.Reader
|
|
||||||
r, n, err := msg.Encode()
|
|
||||||
if err == nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write n bytes into the connection
|
|
||||||
m, err := io.CopyN(conn, r, int64(n))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Sent %d bytes of %s message", m, msg.Name())
|
|
||||||
```
|
|
||||||
|
|
||||||
To receive a CONNECT message from a connection, we can do:
|
|
||||||
|
|
||||||
```
|
|
||||||
// Create a new CONNECT message
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
|
|
||||||
// Decode the message by reading from conn
|
|
||||||
n, err := msg.Decode(conn)
|
|
||||||
```
|
|
||||||
|
|
||||||
If you don't know what type of message is coming down the pipe, you can do something like this:
|
|
||||||
|
|
||||||
```
|
|
||||||
// Create a buffered IO reader for the connection
|
|
||||||
br := bufio.NewReader(conn)
|
|
||||||
|
|
||||||
// Peek at the first byte, which contains the message type
|
|
||||||
b, err := br.Peek(1)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract the type from the first byte
|
|
||||||
t := MessageType(b[0] >> 4)
|
|
||||||
|
|
||||||
// Create a new message
|
|
||||||
msg, err := t.New()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode it from the bufio.Reader
|
|
||||||
n, err := msg.Decode(br)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
```
|
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// The CONNACK Packet is the packet sent by the Server in response to a CONNECT Packet
|
|
||||||
// received from a Client. The first packet sent from the Server to the Client MUST
|
|
||||||
// be a CONNACK Packet [MQTT-3.2.0-1].
|
|
||||||
//
|
|
||||||
// If the Client does not receive a CONNACK Packet from the Server within a reasonable
|
|
||||||
// amount of time, the Client SHOULD close the Network Connection. A "reasonable" amount
|
|
||||||
// of time depends on the type of application and the communications infrastructure.
|
|
||||||
type ConnackMessage struct {
|
|
||||||
header
|
|
||||||
|
|
||||||
sessionPresent bool
|
|
||||||
returnCode ConnackCode
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*ConnackMessage)(nil)
|
|
||||||
|
|
||||||
// NewConnackMessage creates a new CONNACK message
|
|
||||||
func NewConnackMessage() *ConnackMessage {
|
|
||||||
msg := &ConnackMessage{}
|
|
||||||
msg.SetType(CONNACK)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns a string representation of the CONNACK message
|
|
||||||
func (this ConnackMessage) String() string {
|
|
||||||
return fmt.Sprintf("%s, Session Present=%t, Return code=%q\n", this.header, this.sessionPresent, this.returnCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SessionPresent returns the session present flag value
|
|
||||||
func (this *ConnackMessage) SessionPresent() bool {
|
|
||||||
return this.sessionPresent
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSessionPresent sets the value of the session present flag
|
|
||||||
func (this *ConnackMessage) SetSessionPresent(v bool) {
|
|
||||||
if v {
|
|
||||||
this.sessionPresent = true
|
|
||||||
} else {
|
|
||||||
this.sessionPresent = false
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReturnCode returns the return code received for the CONNECT message. The return
|
|
||||||
// type is an error
|
|
||||||
func (this *ConnackMessage) ReturnCode() ConnackCode {
|
|
||||||
return this.returnCode
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnackMessage) SetReturnCode(ret ConnackCode) {
|
|
||||||
this.returnCode = ret
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnackMessage) Len() int {
|
|
||||||
if !this.dirty {
|
|
||||||
return len(this.dbuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.header.msglen() + ml
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnackMessage) Decode(src []byte) (int, error) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.decode(src)
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
b := src[total]
|
|
||||||
|
|
||||||
if b&254 != 0 {
|
|
||||||
return 0, fmt.Errorf("connack/Decode: Bits 7-1 in Connack Acknowledge Flags byte (1) are not 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
this.sessionPresent = b&0x1 == 1
|
|
||||||
total++
|
|
||||||
|
|
||||||
b = src[total]
|
|
||||||
|
|
||||||
// Read return code
|
|
||||||
if b > 5 {
|
|
||||||
return 0, fmt.Errorf("connack/Decode: Invalid CONNACK return code (%d)", b)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.returnCode = ConnackCode(b)
|
|
||||||
total++
|
|
||||||
|
|
||||||
this.dirty = false
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnackMessage) Encode(dst []byte) (int, error) {
|
|
||||||
if !this.dirty {
|
|
||||||
if len(dst) < len(this.dbuf) {
|
|
||||||
return 0, fmt.Errorf("connack/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(dst, this.dbuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CONNACK remaining length fixed at 2 bytes
|
|
||||||
hl := this.header.msglen()
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if len(dst) < hl+ml {
|
|
||||||
return 0, fmt.Errorf("connack/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.encode(dst[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if this.sessionPresent {
|
|
||||||
dst[total] = 1
|
|
||||||
}
|
|
||||||
total++
|
|
||||||
|
|
||||||
if this.returnCode > 5 {
|
|
||||||
return total, fmt.Errorf("connack/Encode: Invalid CONNACK return code (%d)", this.returnCode)
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[total] = this.returnCode.Value()
|
|
||||||
total++
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnackMessage) msglen() int {
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConnackMessageFields(t *testing.T) {
|
|
||||||
msg := NewConnackMessage()
|
|
||||||
|
|
||||||
msg.SetSessionPresent(true)
|
|
||||||
require.True(t, msg.SessionPresent(), "Error setting session present flag.")
|
|
||||||
|
|
||||||
msg.SetSessionPresent(false)
|
|
||||||
require.False(t, msg.SessionPresent(), "Error setting session present flag.")
|
|
||||||
|
|
||||||
msg.SetReturnCode(ConnectionAccepted)
|
|
||||||
require.Equal(t, ConnectionAccepted, msg.ReturnCode(), "Error setting return code.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnackMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNACK << 4),
|
|
||||||
2,
|
|
||||||
0, // session not present
|
|
||||||
0, // connection accepted
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnackMessage()
|
|
||||||
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.False(t, msg.SessionPresent(), "Error decoding session present flag.")
|
|
||||||
require.Equal(t, ConnectionAccepted, msg.ReturnCode(), "Error decoding return code.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// testing wrong message length
|
|
||||||
func TestConnackMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNACK << 4),
|
|
||||||
3,
|
|
||||||
0, // session not present
|
|
||||||
0, // connection accepted
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnackMessage()
|
|
||||||
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
require.Error(t, err, "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// testing wrong message size
|
|
||||||
func TestConnackMessageDecode3(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNACK << 4),
|
|
||||||
2,
|
|
||||||
0, // session not present
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnackMessage()
|
|
||||||
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
require.Error(t, err, "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// testing wrong reserve bits
|
|
||||||
func TestConnackMessageDecode4(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNACK << 4),
|
|
||||||
2,
|
|
||||||
64, // <- wrong size
|
|
||||||
0, // connection accepted
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnackMessage()
|
|
||||||
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
require.Error(t, err, "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// testing invalid return code
|
|
||||||
func TestConnackMessageDecode5(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNACK << 4),
|
|
||||||
2,
|
|
||||||
0,
|
|
||||||
6, // <- wrong code
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnackMessage()
|
|
||||||
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
require.Error(t, err, "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnackMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNACK << 4),
|
|
||||||
2,
|
|
||||||
1, // session present
|
|
||||||
0, // connection accepted
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnackMessage()
|
|
||||||
msg.SetReturnCode(ConnectionAccepted)
|
|
||||||
msg.SetSessionPresent(true)
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error encoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error encoding connack message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestConnackDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNACK << 4),
|
|
||||||
2,
|
|
||||||
0, // session not present
|
|
||||||
0, // connection accepted
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnackMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,89 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
// ConnackCode is the type representing the return code in the CONNACK message,
|
|
||||||
// returned after the initial CONNECT message
|
|
||||||
type ConnackCode byte
|
|
||||||
|
|
||||||
const (
|
|
||||||
// Connection accepted
|
|
||||||
ConnectionAccepted ConnackCode = iota
|
|
||||||
|
|
||||||
// The Server does not support the level of the MQTT protocol requested by the Client
|
|
||||||
ErrInvalidProtocolVersion
|
|
||||||
|
|
||||||
// The Client identifier is correct UTF-8 but not allowed by the server
|
|
||||||
ErrIdentifierRejected
|
|
||||||
|
|
||||||
// The Network Connection has been made but the MQTT service is unavailable
|
|
||||||
ErrServerUnavailable
|
|
||||||
|
|
||||||
// The data in the user name or password is malformed
|
|
||||||
ErrBadUsernameOrPassword
|
|
||||||
|
|
||||||
// The Client is not authorized to connect
|
|
||||||
ErrNotAuthorized
|
|
||||||
)
|
|
||||||
|
|
||||||
// Value returns the value of the ConnackCode, which is just the byte representation
|
|
||||||
func (this ConnackCode) Value() byte {
|
|
||||||
return byte(this)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Desc returns the description of the ConnackCode
|
|
||||||
func (this ConnackCode) Desc() string {
|
|
||||||
switch this {
|
|
||||||
case 0:
|
|
||||||
return "Connection accepted"
|
|
||||||
case 1:
|
|
||||||
return "The Server does not support the level of the MQTT protocol requested by the Client"
|
|
||||||
case 2:
|
|
||||||
return "The Client identifier is correct UTF-8 but not allowed by the server"
|
|
||||||
case 3:
|
|
||||||
return "The Network Connection has been made but the MQTT service is unavailable"
|
|
||||||
case 4:
|
|
||||||
return "The data in the user name or password is malformed"
|
|
||||||
case 5:
|
|
||||||
return "The Client is not authorized to connect"
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Valid checks to see if the ConnackCode is valid. Currently valid codes are <= 5
|
|
||||||
func (this ConnackCode) Valid() bool {
|
|
||||||
return this <= 5
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error returns the corresonding error string for the ConnackCode
|
|
||||||
func (this ConnackCode) Error() string {
|
|
||||||
switch this {
|
|
||||||
case 0:
|
|
||||||
return "Connection accepted"
|
|
||||||
case 1:
|
|
||||||
return "Connection Refused, unacceptable protocol version"
|
|
||||||
case 2:
|
|
||||||
return "Connection Refused, identifier rejected"
|
|
||||||
case 3:
|
|
||||||
return "Connection Refused, Server unavailable"
|
|
||||||
case 4:
|
|
||||||
return "Connection Refused, bad user name or password"
|
|
||||||
case 5:
|
|
||||||
return "Connection Refused, not authorized"
|
|
||||||
}
|
|
||||||
|
|
||||||
return "Unknown error"
|
|
||||||
}
|
|
||||||
@@ -1,635 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
"regexp"
|
|
||||||
)
|
|
||||||
|
|
||||||
var clientIdRegexp *regexp.Regexp
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
// Added space for Paho compliance test
|
|
||||||
// Added underscore (_) for MQTT C client test
|
|
||||||
clientIdRegexp = regexp.MustCompile("^[0-9a-zA-Z _]*$")
|
|
||||||
}
|
|
||||||
|
|
||||||
// After a Network Connection is established by a Client to a Server, the first Packet
|
|
||||||
// sent from the Client to the Server MUST be a CONNECT Packet [MQTT-3.1.0-1].
|
|
||||||
//
|
|
||||||
// A Client can only send the CONNECT Packet once over a Network Connection. The Server
|
|
||||||
// MUST process a second CONNECT Packet sent from a Client as a protocol violation and
|
|
||||||
// disconnect the Client [MQTT-3.1.0-2]. See section 4.8 for information about
|
|
||||||
// handling errors.
|
|
||||||
type ConnectMessage struct {
|
|
||||||
header
|
|
||||||
|
|
||||||
// 7: username flag
|
|
||||||
// 6: password flag
|
|
||||||
// 5: will retain
|
|
||||||
// 4-3: will QoS
|
|
||||||
// 2: will flag
|
|
||||||
// 1: clean session
|
|
||||||
// 0: reserved
|
|
||||||
connectFlags byte
|
|
||||||
|
|
||||||
version byte
|
|
||||||
|
|
||||||
keepAlive uint16
|
|
||||||
|
|
||||||
protoName,
|
|
||||||
clientId,
|
|
||||||
willTopic,
|
|
||||||
willMessage,
|
|
||||||
username,
|
|
||||||
password []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*ConnectMessage)(nil)
|
|
||||||
|
|
||||||
// NewConnectMessage creates a new CONNECT message.
|
|
||||||
func NewConnectMessage() *ConnectMessage {
|
|
||||||
msg := &ConnectMessage{}
|
|
||||||
msg.SetType(CONNECT)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns a string representation of the CONNECT message
|
|
||||||
func (this ConnectMessage) String() string {
|
|
||||||
return fmt.Sprintf("%s, Connect Flags=%08b, Version=%d, KeepAlive=%d, Client ID=%q, Will Topic=%q, Will Message=%q, Username=%q, Password=%q",
|
|
||||||
this.header,
|
|
||||||
this.connectFlags,
|
|
||||||
this.Version(),
|
|
||||||
this.KeepAlive(),
|
|
||||||
this.ClientId(),
|
|
||||||
this.WillTopic(),
|
|
||||||
this.WillMessage(),
|
|
||||||
this.Username(),
|
|
||||||
this.Password(),
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Version returns the the 8 bit unsigned value that represents the revision level
|
|
||||||
// of the protocol used by the Client. The value of the Protocol Level field for
|
|
||||||
// the version 3.1.1 of the protocol is 4 (0x04).
|
|
||||||
func (this *ConnectMessage) Version() byte {
|
|
||||||
return this.version
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetVersion sets the version value of the CONNECT message
|
|
||||||
func (this *ConnectMessage) SetVersion(v byte) error {
|
|
||||||
if _, ok := SupportedVersions[v]; !ok {
|
|
||||||
return fmt.Errorf("connect/SetVersion: Invalid version number %d", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.version = v
|
|
||||||
this.dirty = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// CleanSession returns the bit that specifies the handling of the Session state.
|
|
||||||
// The Client and Server can store Session state to enable reliable messaging to
|
|
||||||
// continue across a sequence of Network Connections. This bit is used to control
|
|
||||||
// the lifetime of the Session state.
|
|
||||||
func (this *ConnectMessage) CleanSession() bool {
|
|
||||||
return ((this.connectFlags >> 1) & 0x1) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetCleanSession sets the bit that specifies the handling of the Session state.
|
|
||||||
func (this *ConnectMessage) SetCleanSession(v bool) {
|
|
||||||
if v {
|
|
||||||
this.connectFlags |= 0x2 // 00000010
|
|
||||||
} else {
|
|
||||||
this.connectFlags &= 253 // 11111101
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// WillFlag returns the bit that specifies whether a Will Message should be stored
|
|
||||||
// on the server. If the Will Flag is set to 1 this indicates that, if the Connect
|
|
||||||
// request is accepted, a Will Message MUST be stored on the Server and associated
|
|
||||||
// with the Network Connection.
|
|
||||||
func (this *ConnectMessage) WillFlag() bool {
|
|
||||||
return ((this.connectFlags >> 2) & 0x1) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWillFlag sets the bit that specifies whether a Will Message should be stored
|
|
||||||
// on the server.
|
|
||||||
func (this *ConnectMessage) SetWillFlag(v bool) {
|
|
||||||
if v {
|
|
||||||
this.connectFlags |= 0x4 // 00000100
|
|
||||||
} else {
|
|
||||||
this.connectFlags &= 251 // 11111011
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// WillQos returns the two bits that specify the QoS level to be used when publishing
|
|
||||||
// the Will Message.
|
|
||||||
func (this *ConnectMessage) WillQos() byte {
|
|
||||||
return (this.connectFlags >> 3) & 0x3
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWillQos sets the two bits that specify the QoS level to be used when publishing
|
|
||||||
// the Will Message.
|
|
||||||
func (this *ConnectMessage) SetWillQos(qos byte) error {
|
|
||||||
if qos != QosAtMostOnce && qos != QosAtLeastOnce && qos != QosExactlyOnce {
|
|
||||||
return fmt.Errorf("connect/SetWillQos: Invalid QoS level %d", qos)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.connectFlags = (this.connectFlags & 231) | (qos << 3) // 231 = 11100111
|
|
||||||
this.dirty = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WillRetain returns the bit specifies if the Will Message is to be Retained when it
|
|
||||||
// is published.
|
|
||||||
func (this *ConnectMessage) WillRetain() bool {
|
|
||||||
return ((this.connectFlags >> 5) & 0x1) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWillRetain sets the bit specifies if the Will Message is to be Retained when it
|
|
||||||
// is published.
|
|
||||||
func (this *ConnectMessage) SetWillRetain(v bool) {
|
|
||||||
if v {
|
|
||||||
this.connectFlags |= 32 // 00100000
|
|
||||||
} else {
|
|
||||||
this.connectFlags &= 223 // 11011111
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// UsernameFlag returns the bit that specifies whether a user name is present in the
|
|
||||||
// payload.
|
|
||||||
func (this *ConnectMessage) UsernameFlag() bool {
|
|
||||||
return ((this.connectFlags >> 7) & 0x1) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUsernameFlag sets the bit that specifies whether a user name is present in the
|
|
||||||
// payload.
|
|
||||||
func (this *ConnectMessage) SetUsernameFlag(v bool) {
|
|
||||||
if v {
|
|
||||||
this.connectFlags |= 128 // 10000000
|
|
||||||
} else {
|
|
||||||
this.connectFlags &= 127 // 01111111
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// PasswordFlag returns the bit that specifies whether a password is present in the
|
|
||||||
// payload.
|
|
||||||
func (this *ConnectMessage) PasswordFlag() bool {
|
|
||||||
return ((this.connectFlags >> 6) & 0x1) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPasswordFlag sets the bit that specifies whether a password is present in the
|
|
||||||
// payload.
|
|
||||||
func (this *ConnectMessage) SetPasswordFlag(v bool) {
|
|
||||||
if v {
|
|
||||||
this.connectFlags |= 64 // 01000000
|
|
||||||
} else {
|
|
||||||
this.connectFlags &= 191 // 10111111
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// KeepAlive returns a time interval measured in seconds. Expressed as a 16-bit word,
|
|
||||||
// it is the maximum time interval that is permitted to elapse between the point at
|
|
||||||
// which the Client finishes transmitting one Control Packet and the point it starts
|
|
||||||
// sending the next.
|
|
||||||
func (this *ConnectMessage) KeepAlive() uint16 {
|
|
||||||
return this.keepAlive
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKeepAlive sets the time interval in which the server should keep the connection
|
|
||||||
// alive.
|
|
||||||
func (this *ConnectMessage) SetKeepAlive(v uint16) {
|
|
||||||
this.keepAlive = v
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// ClientId returns an ID that identifies the Client to the Server. Each Client
|
|
||||||
// connecting to the Server has a unique ClientId. The ClientId MUST be used by
|
|
||||||
// Clients and by Servers to identify state that they hold relating to this MQTT
|
|
||||||
// Session between the Client and the Server
|
|
||||||
func (this *ConnectMessage) ClientId() []byte {
|
|
||||||
return this.clientId
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetClientId sets an ID that identifies the Client to the Server.
|
|
||||||
func (this *ConnectMessage) SetClientId(v []byte) error {
|
|
||||||
if len(v) > 0 && !this.validClientId(v) {
|
|
||||||
return ErrIdentifierRejected
|
|
||||||
}
|
|
||||||
|
|
||||||
this.clientId = v
|
|
||||||
this.dirty = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// WillTopic returns the topic in which the Will Message should be published to.
|
|
||||||
// If the Will Flag is set to 1, the Will Topic must be in the payload.
|
|
||||||
func (this *ConnectMessage) WillTopic() []byte {
|
|
||||||
return this.willTopic
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWillTopic sets the topic in which the Will Message should be published to.
|
|
||||||
func (this *ConnectMessage) SetWillTopic(v []byte) {
|
|
||||||
this.willTopic = v
|
|
||||||
|
|
||||||
if len(v) > 0 {
|
|
||||||
this.SetWillFlag(true)
|
|
||||||
} else if len(this.willMessage) == 0 {
|
|
||||||
this.SetWillFlag(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// WillMessage returns the Will Message that is to be published to the Will Topic.
|
|
||||||
func (this *ConnectMessage) WillMessage() []byte {
|
|
||||||
return this.willMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetWillMessage sets the Will Message that is to be published to the Will Topic.
|
|
||||||
func (this *ConnectMessage) SetWillMessage(v []byte) {
|
|
||||||
this.willMessage = v
|
|
||||||
|
|
||||||
if len(v) > 0 {
|
|
||||||
this.SetWillFlag(true)
|
|
||||||
} else if len(this.willTopic) == 0 {
|
|
||||||
this.SetWillFlag(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Username returns the username from the payload. If the User Name Flag is set to 1,
|
|
||||||
// this must be in the payload. It can be used by the Server for authentication and
|
|
||||||
// authorization.
|
|
||||||
func (this *ConnectMessage) Username() []byte {
|
|
||||||
return this.username
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUsername sets the username for authentication.
|
|
||||||
func (this *ConnectMessage) SetUsername(v []byte) {
|
|
||||||
this.username = v
|
|
||||||
|
|
||||||
if len(v) > 0 {
|
|
||||||
this.SetUsernameFlag(true)
|
|
||||||
} else {
|
|
||||||
this.SetUsernameFlag(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Password returns the password from the payload. If the Password Flag is set to 1,
|
|
||||||
// this must be in the payload. It can be used by the Server for authentication and
|
|
||||||
// authorization.
|
|
||||||
func (this *ConnectMessage) Password() []byte {
|
|
||||||
return this.password
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPassword sets the username for authentication.
|
|
||||||
func (this *ConnectMessage) SetPassword(v []byte) {
|
|
||||||
this.password = v
|
|
||||||
|
|
||||||
if len(v) > 0 {
|
|
||||||
this.SetPasswordFlag(true)
|
|
||||||
} else {
|
|
||||||
this.SetPasswordFlag(false)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnectMessage) Len() int {
|
|
||||||
if !this.dirty {
|
|
||||||
return len(this.dbuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.header.msglen() + ml
|
|
||||||
}
|
|
||||||
|
|
||||||
// For the CONNECT message, the error returned could be a ConnackReturnCode, so
|
|
||||||
// be sure to check that. Otherwise it's a generic error. If a generic error is
|
|
||||||
// returned, this Message should be considered invalid.
|
|
||||||
//
|
|
||||||
// Caller should call ValidConnackError(err) to see if the returned error is
|
|
||||||
// a Connack error. If so, caller should send the Client back the corresponding
|
|
||||||
// CONNACK message.
|
|
||||||
func (this *ConnectMessage) Decode(src []byte) (int, error) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.decode(src[total:])
|
|
||||||
if err != nil {
|
|
||||||
return total + n, err
|
|
||||||
}
|
|
||||||
total += n
|
|
||||||
|
|
||||||
if n, err = this.decodeMessage(src[total:]); err != nil {
|
|
||||||
return total + n, err
|
|
||||||
}
|
|
||||||
total += n
|
|
||||||
|
|
||||||
this.dirty = false
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnectMessage) Encode(dst []byte) (int, error) {
|
|
||||||
if !this.dirty {
|
|
||||||
if len(dst) < len(this.dbuf) {
|
|
||||||
return 0, fmt.Errorf("connect/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(dst, this.dbuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if this.Type() != CONNECT {
|
|
||||||
return 0, fmt.Errorf("connect/Encode: Invalid message type. Expecting %d, got %d", CONNECT, this.Type())
|
|
||||||
}
|
|
||||||
|
|
||||||
_, ok := SupportedVersions[this.version]
|
|
||||||
if !ok {
|
|
||||||
return 0, ErrInvalidProtocolVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
hl := this.header.msglen()
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if len(dst) < hl+ml {
|
|
||||||
return 0, fmt.Errorf("connect/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.encode(dst[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err = this.encodeMessage(dst[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnectMessage) encodeMessage(dst []byte) (int, error) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := writeLPBytes(dst[total:], []byte(SupportedVersions[this.version]))
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[total] = this.version
|
|
||||||
total += 1
|
|
||||||
|
|
||||||
dst[total] = this.connectFlags
|
|
||||||
total += 1
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(dst[total:], this.keepAlive)
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
n, err = writeLPBytes(dst[total:], this.clientId)
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if this.WillFlag() {
|
|
||||||
n, err = writeLPBytes(dst[total:], this.willTopic)
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err = writeLPBytes(dst[total:], this.willMessage)
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// According to the 3.1 spec, it's possible that the usernameFlag is set,
|
|
||||||
// but the username string is missing.
|
|
||||||
if this.UsernameFlag() && len(this.username) > 0 {
|
|
||||||
n, err = writeLPBytes(dst[total:], this.username)
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// According to the 3.1 spec, it's possible that the passwordFlag is set,
|
|
||||||
// but the password string is missing.
|
|
||||||
if this.PasswordFlag() && len(this.password) > 0 {
|
|
||||||
n, err = writeLPBytes(dst[total:], this.password)
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnectMessage) decodeMessage(src []byte) (int, error) {
|
|
||||||
var err error
|
|
||||||
n, total := 0, 0
|
|
||||||
|
|
||||||
this.protoName, n, err = readLPBytes(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
this.version = src[total]
|
|
||||||
total++
|
|
||||||
|
|
||||||
if verstr, ok := SupportedVersions[this.version]; !ok {
|
|
||||||
return total, ErrInvalidProtocolVersion
|
|
||||||
} else if verstr != string(this.protoName) {
|
|
||||||
return total, ErrInvalidProtocolVersion
|
|
||||||
}
|
|
||||||
|
|
||||||
this.connectFlags = src[total]
|
|
||||||
total++
|
|
||||||
|
|
||||||
if this.connectFlags&0x1 != 0 {
|
|
||||||
return total, fmt.Errorf("connect/decodeMessage: Connect Flags reserved bit 0 is not 0")
|
|
||||||
}
|
|
||||||
|
|
||||||
if this.WillQos() > QosExactlyOnce {
|
|
||||||
return total, fmt.Errorf("connect/decodeMessage: Invalid QoS level (%d) for %s message", this.WillQos(), this.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
if !this.WillFlag() && (this.WillRetain() || this.WillQos() != QosAtMostOnce) {
|
|
||||||
return total, fmt.Errorf("connect/decodeMessage: Protocol violation: If the Will Flag (%t) is set to 0 the Will QoS (%d) and Will Retain (%t) fields MUST be set to zero", this.WillFlag(), this.WillQos(), this.WillRetain())
|
|
||||||
}
|
|
||||||
|
|
||||||
if this.UsernameFlag() && !this.PasswordFlag() {
|
|
||||||
return total, fmt.Errorf("connect/decodeMessage: Username flag is set but Password flag is not set")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(src[total:]) < 2 {
|
|
||||||
return 0, fmt.Errorf("connect/decodeMessage: Insufficient buffer size. Expecting %d, got %d.", 2, len(src[total:]))
|
|
||||||
}
|
|
||||||
|
|
||||||
this.keepAlive = binary.BigEndian.Uint16(src[total:])
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
this.clientId, n, err = readLPBytes(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the Client supplies a zero-byte ClientId, the Client MUST also set CleanSession to 1
|
|
||||||
if len(this.clientId) == 0 && !this.CleanSession() {
|
|
||||||
return total, ErrIdentifierRejected
|
|
||||||
}
|
|
||||||
|
|
||||||
// The ClientId must contain only characters 0-9, a-z, and A-Z
|
|
||||||
// We also support ClientId longer than 23 encoded bytes
|
|
||||||
// We do not support ClientId outside of the above characters
|
|
||||||
if len(this.clientId) > 0 && !this.validClientId(this.clientId) {
|
|
||||||
return total, ErrIdentifierRejected
|
|
||||||
}
|
|
||||||
|
|
||||||
if this.WillFlag() {
|
|
||||||
this.willTopic, n, err = readLPBytes(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
this.willMessage, n, err = readLPBytes(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// According to the 3.1 spec, it's possible that the passwordFlag is set,
|
|
||||||
// but the password string is missing.
|
|
||||||
if this.UsernameFlag() && len(src[total:]) > 0 {
|
|
||||||
this.username, n, err = readLPBytes(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// According to the 3.1 spec, it's possible that the passwordFlag is set,
|
|
||||||
// but the password string is missing.
|
|
||||||
if this.PasswordFlag() && len(src[total:]) > 0 {
|
|
||||||
this.password, n, err = readLPBytes(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *ConnectMessage) msglen() int {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
verstr, ok := SupportedVersions[this.version]
|
|
||||||
if !ok {
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2 bytes protocol name length
|
|
||||||
// n bytes protocol name
|
|
||||||
// 1 byte protocol version
|
|
||||||
// 1 byte connect flags
|
|
||||||
// 2 bytes keep alive timer
|
|
||||||
total += 2 + len(verstr) + 1 + 1 + 2
|
|
||||||
|
|
||||||
// Add the clientID length, 2 is the length prefix
|
|
||||||
total += 2 + len(this.clientId)
|
|
||||||
|
|
||||||
// Add the will topic and will message length, and the length prefixes
|
|
||||||
if this.WillFlag() {
|
|
||||||
total += 2 + len(this.willTopic) + 2 + len(this.willMessage)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the username length
|
|
||||||
// According to the 3.1 spec, it's possible that the usernameFlag is set,
|
|
||||||
// but the user name string is missing.
|
|
||||||
if this.UsernameFlag() && len(this.username) > 0 {
|
|
||||||
total += 2 + len(this.username)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add the password length
|
|
||||||
// According to the 3.1 spec, it's possible that the passwordFlag is set,
|
|
||||||
// but the password string is missing.
|
|
||||||
if this.PasswordFlag() && len(this.password) > 0 {
|
|
||||||
total += 2 + len(this.password)
|
|
||||||
}
|
|
||||||
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
|
|
||||||
// validClientId checks the client ID, which is a slice of bytes, to see if it's valid.
|
|
||||||
// Client ID is valid if it meets the requirement from the MQTT spec:
|
|
||||||
// The Server MUST allow ClientIds which are between 1 and 23 UTF-8 encoded bytes in length,
|
|
||||||
// and that contain only the characters
|
|
||||||
//
|
|
||||||
// "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
|
||||||
func (this *ConnectMessage) validClientId(cid []byte) bool {
|
|
||||||
// Fixed https://github.com/surgemq/surgemq/issues/4
|
|
||||||
//if len(cid) > 23 {
|
|
||||||
// return false
|
|
||||||
//}
|
|
||||||
|
|
||||||
if this.Version() == 0x3 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return clientIdRegexp.Match(cid)
|
|
||||||
}
|
|
||||||
@@ -1,373 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestConnectMessageFields(t *testing.T) {
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
|
|
||||||
err := msg.SetVersion(0x3)
|
|
||||||
require.NoError(t, err, "Error setting message version.")
|
|
||||||
|
|
||||||
require.Equal(t, 0x3, int(msg.Version()), "Incorrect version number")
|
|
||||||
|
|
||||||
err = msg.SetVersion(0x5)
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
msg.SetCleanSession(true)
|
|
||||||
require.True(t, msg.CleanSession(), "Error setting clean session flag.")
|
|
||||||
|
|
||||||
msg.SetCleanSession(false)
|
|
||||||
require.False(t, msg.CleanSession(), "Error setting clean session flag.")
|
|
||||||
|
|
||||||
msg.SetWillFlag(true)
|
|
||||||
require.True(t, msg.WillFlag(), "Error setting will flag.")
|
|
||||||
|
|
||||||
msg.SetWillFlag(false)
|
|
||||||
require.False(t, msg.WillFlag(), "Error setting will flag.")
|
|
||||||
|
|
||||||
msg.SetWillRetain(true)
|
|
||||||
require.True(t, msg.WillRetain(), "Error setting will retain.")
|
|
||||||
|
|
||||||
msg.SetWillRetain(false)
|
|
||||||
require.False(t, msg.WillRetain(), "Error setting will retain.")
|
|
||||||
|
|
||||||
msg.SetPasswordFlag(true)
|
|
||||||
require.True(t, msg.PasswordFlag(), "Error setting password flag.")
|
|
||||||
|
|
||||||
msg.SetPasswordFlag(false)
|
|
||||||
require.False(t, msg.PasswordFlag(), "Error setting password flag.")
|
|
||||||
|
|
||||||
msg.SetUsernameFlag(true)
|
|
||||||
require.True(t, msg.UsernameFlag(), "Error setting username flag.")
|
|
||||||
|
|
||||||
msg.SetUsernameFlag(false)
|
|
||||||
require.False(t, msg.UsernameFlag(), "Error setting username flag.")
|
|
||||||
|
|
||||||
msg.SetWillQos(1)
|
|
||||||
require.Equal(t, 1, int(msg.WillQos()), "Error setting will QoS.")
|
|
||||||
|
|
||||||
err = msg.SetWillQos(4)
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
err = msg.SetClientId([]byte("j0j0jfajf02j0asdjf"))
|
|
||||||
require.NoError(t, err, "Error setting client ID")
|
|
||||||
|
|
||||||
require.Equal(t, "j0j0jfajf02j0asdjf", string(msg.ClientId()), "Error setting client ID.")
|
|
||||||
|
|
||||||
err = msg.SetClientId([]byte("this is good for v3"))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
msg.SetVersion(0x4)
|
|
||||||
|
|
||||||
err = msg.SetClientId([]byte("this is no good for v4!"))
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
msg.SetVersion(0x3)
|
|
||||||
|
|
||||||
msg.SetWillTopic([]byte("willtopic"))
|
|
||||||
require.Equal(t, "willtopic", string(msg.WillTopic()), "Error setting will topic.")
|
|
||||||
|
|
||||||
require.True(t, msg.WillFlag(), "Error setting will flag.")
|
|
||||||
|
|
||||||
msg.SetWillTopic([]byte(""))
|
|
||||||
require.Equal(t, "", string(msg.WillTopic()), "Error setting will topic.")
|
|
||||||
|
|
||||||
require.False(t, msg.WillFlag(), "Error setting will flag.")
|
|
||||||
|
|
||||||
msg.SetWillMessage([]byte("this is a will message"))
|
|
||||||
require.Equal(t, "this is a will message", string(msg.WillMessage()), "Error setting will message.")
|
|
||||||
|
|
||||||
require.True(t, msg.WillFlag(), "Error setting will flag.")
|
|
||||||
|
|
||||||
msg.SetWillMessage([]byte(""))
|
|
||||||
require.Equal(t, "", string(msg.WillMessage()), "Error setting will topic.")
|
|
||||||
|
|
||||||
require.False(t, msg.WillFlag(), "Error setting will flag.")
|
|
||||||
|
|
||||||
msg.SetWillTopic([]byte("willtopic"))
|
|
||||||
msg.SetWillMessage([]byte("this is a will message"))
|
|
||||||
msg.SetWillTopic([]byte(""))
|
|
||||||
require.True(t, msg.WillFlag(), "Error setting will topic.")
|
|
||||||
|
|
||||||
msg.SetUsername([]byte("myname"))
|
|
||||||
require.Equal(t, "myname", string(msg.Username()), "Error setting will message.")
|
|
||||||
|
|
||||||
require.True(t, msg.UsernameFlag(), "Error setting will flag.")
|
|
||||||
|
|
||||||
msg.SetUsername([]byte(""))
|
|
||||||
require.Equal(t, "", string(msg.Username()), "Error setting will message.")
|
|
||||||
|
|
||||||
require.False(t, msg.UsernameFlag(), "Error setting will flag.")
|
|
||||||
|
|
||||||
msg.SetPassword([]byte("myname"))
|
|
||||||
require.Equal(t, "myname", string(msg.Password()), "Error setting will message.")
|
|
||||||
|
|
||||||
require.True(t, msg.PasswordFlag(), "Error setting will flag.")
|
|
||||||
|
|
||||||
msg.SetPassword([]byte(""))
|
|
||||||
require.Equal(t, "", string(msg.Password()), "Error setting will message.")
|
|
||||||
|
|
||||||
require.False(t, msg.PasswordFlag(), "Error setting will flag.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNECT << 4),
|
|
||||||
60,
|
|
||||||
0, // Length MSB (0)
|
|
||||||
4, // Length LSB (4)
|
|
||||||
'M', 'Q', 'T', 'T',
|
|
||||||
4, // Protocol level 4
|
|
||||||
206, // connect flags 11001110, will QoS = 01
|
|
||||||
0, // Keep Alive MSB (0)
|
|
||||||
10, // Keep Alive LSB (10)
|
|
||||||
0, // Client ID MSB (0)
|
|
||||||
7, // Client ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Will Topic MSB (0)
|
|
||||||
4, // Will Topic LSB (4)
|
|
||||||
'w', 'i', 'l', 'l',
|
|
||||||
0, // Will Message MSB (0)
|
|
||||||
12, // Will Message LSB (12)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
0, // Username ID MSB (0)
|
|
||||||
7, // Username ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Password ID MSB (0)
|
|
||||||
10, // Password ID LSB (10)
|
|
||||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, 206, int(msg.connectFlags), "Incorrect flag value.")
|
|
||||||
require.Equal(t, 10, int(msg.KeepAlive()), "Incorrect KeepAlive value.")
|
|
||||||
require.Equal(t, "surgemq", string(msg.ClientId()), "Incorrect client ID value.")
|
|
||||||
require.Equal(t, "will", string(msg.WillTopic()), "Incorrect will topic value.")
|
|
||||||
require.Equal(t, "send me home", string(msg.WillMessage()), "Incorrect will message value.")
|
|
||||||
require.Equal(t, "surgemq", string(msg.Username()), "Incorrect username value.")
|
|
||||||
require.Equal(t, "verysecret", string(msg.Password()), "Incorrect password value.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectMessageDecode2(t *testing.T) {
|
|
||||||
// missing last byte 't'
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNECT << 4),
|
|
||||||
60,
|
|
||||||
0, // Length MSB (0)
|
|
||||||
4, // Length LSB (4)
|
|
||||||
'M', 'Q', 'T', 'T',
|
|
||||||
4, // Protocol level 4
|
|
||||||
206, // connect flags 11001110, will QoS = 01
|
|
||||||
0, // Keep Alive MSB (0)
|
|
||||||
10, // Keep Alive LSB (10)
|
|
||||||
0, // Client ID MSB (0)
|
|
||||||
7, // Client ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Will Topic MSB (0)
|
|
||||||
4, // Will Topic LSB (4)
|
|
||||||
'w', 'i', 'l', 'l',
|
|
||||||
0, // Will Message MSB (0)
|
|
||||||
12, // Will Message LSB (12)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
0, // Username ID MSB (0)
|
|
||||||
7, // Username ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Password ID MSB (0)
|
|
||||||
10, // Password ID LSB (10)
|
|
||||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectMessageDecode3(t *testing.T) {
|
|
||||||
// extra bytes
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNECT << 4),
|
|
||||||
60,
|
|
||||||
0, // Length MSB (0)
|
|
||||||
4, // Length LSB (4)
|
|
||||||
'M', 'Q', 'T', 'T',
|
|
||||||
4, // Protocol level 4
|
|
||||||
206, // connect flags 11001110, will QoS = 01
|
|
||||||
0, // Keep Alive MSB (0)
|
|
||||||
10, // Keep Alive LSB (10)
|
|
||||||
0, // Client ID MSB (0)
|
|
||||||
7, // Client ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Will Topic MSB (0)
|
|
||||||
4, // Will Topic LSB (4)
|
|
||||||
'w', 'i', 'l', 'l',
|
|
||||||
0, // Will Message MSB (0)
|
|
||||||
12, // Will Message LSB (12)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
0, // Username ID MSB (0)
|
|
||||||
7, // Username ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Password ID MSB (0)
|
|
||||||
10, // Password ID LSB (10)
|
|
||||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
|
||||||
'e', 'x', 't', 'r', 'a',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 62, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectMessageDecode4(t *testing.T) {
|
|
||||||
// missing client Id, clean session == 0
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNECT << 4),
|
|
||||||
53,
|
|
||||||
0, // Length MSB (0)
|
|
||||||
4, // Length LSB (4)
|
|
||||||
'M', 'Q', 'T', 'T',
|
|
||||||
4, // Protocol level 4
|
|
||||||
204, // connect flags 11001110, will QoS = 01
|
|
||||||
0, // Keep Alive MSB (0)
|
|
||||||
10, // Keep Alive LSB (10)
|
|
||||||
0, // Client ID MSB (0)
|
|
||||||
0, // Client ID LSB (0)
|
|
||||||
0, // Will Topic MSB (0)
|
|
||||||
4, // Will Topic LSB (4)
|
|
||||||
'w', 'i', 'l', 'l',
|
|
||||||
0, // Will Message MSB (0)
|
|
||||||
12, // Will Message LSB (12)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
0, // Username ID MSB (0)
|
|
||||||
7, // Username ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Password ID MSB (0)
|
|
||||||
10, // Password ID LSB (10)
|
|
||||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnectMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNECT << 4),
|
|
||||||
60,
|
|
||||||
0, // Length MSB (0)
|
|
||||||
4, // Length LSB (4)
|
|
||||||
'M', 'Q', 'T', 'T',
|
|
||||||
4, // Protocol level 4
|
|
||||||
206, // connect flags 11001110, will QoS = 01
|
|
||||||
0, // Keep Alive MSB (0)
|
|
||||||
10, // Keep Alive LSB (10)
|
|
||||||
0, // Client ID MSB (0)
|
|
||||||
7, // Client ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Will Topic MSB (0)
|
|
||||||
4, // Will Topic LSB (4)
|
|
||||||
'w', 'i', 'l', 'l',
|
|
||||||
0, // Will Message MSB (0)
|
|
||||||
12, // Will Message LSB (12)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
0, // Username ID MSB (0)
|
|
||||||
7, // Username ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Password ID MSB (0)
|
|
||||||
10, // Password ID LSB (10)
|
|
||||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
msg.SetWillQos(1)
|
|
||||||
msg.SetVersion(4)
|
|
||||||
msg.SetCleanSession(true)
|
|
||||||
msg.SetClientId([]byte("surgemq"))
|
|
||||||
msg.SetKeepAlive(10)
|
|
||||||
msg.SetWillTopic([]byte("will"))
|
|
||||||
msg.SetWillMessage([]byte("send me home"))
|
|
||||||
msg.SetUsername([]byte("surgemq"))
|
|
||||||
msg.SetPassword([]byte("verysecret"))
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestConnectDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(CONNECT << 4),
|
|
||||||
60,
|
|
||||||
0, // Length MSB (0)
|
|
||||||
4, // Length LSB (4)
|
|
||||||
'M', 'Q', 'T', 'T',
|
|
||||||
4, // Protocol level 4
|
|
||||||
206, // connect flags 11001110, will QoS = 01
|
|
||||||
0, // Keep Alive MSB (0)
|
|
||||||
10, // Keep Alive LSB (10)
|
|
||||||
0, // Client ID MSB (0)
|
|
||||||
7, // Client ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Will Topic MSB (0)
|
|
||||||
4, // Will Topic LSB (4)
|
|
||||||
'w', 'i', 'l', 'l',
|
|
||||||
0, // Will Message MSB (0)
|
|
||||||
12, // Will Message LSB (12)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
0, // Username ID MSB (0)
|
|
||||||
7, // Username ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Password ID MSB (0)
|
|
||||||
10, // Password ID LSB (10)
|
|
||||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// The DISCONNECT Packet is the final Control Packet sent from the Client to the Server.
|
|
||||||
// It indicates that the Client is disconnecting cleanly.
|
|
||||||
type DisconnectMessage struct {
|
|
||||||
header
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*DisconnectMessage)(nil)
|
|
||||||
|
|
||||||
// NewDisconnectMessage creates a new DISCONNECT message.
|
|
||||||
func NewDisconnectMessage() *DisconnectMessage {
|
|
||||||
msg := &DisconnectMessage{}
|
|
||||||
msg.SetType(DISCONNECT)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *DisconnectMessage) Decode(src []byte) (int, error) {
|
|
||||||
return this.header.decode(src)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *DisconnectMessage) Encode(dst []byte) (int, error) {
|
|
||||||
if !this.dirty {
|
|
||||||
if len(dst) < len(this.dbuf) {
|
|
||||||
return 0, fmt.Errorf("disconnect/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(dst, this.dbuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.header.encode(dst)
|
|
||||||
}
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestDisconnectMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(DISCONNECT << 4),
|
|
||||||
0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewDisconnectMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, DISCONNECT, msg.Type(), "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDisconnectMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(DISCONNECT << 4),
|
|
||||||
0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewDisconnectMessage()
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestDisconnectDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(DISCONNECT << 4),
|
|
||||||
0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewDisconnectMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,141 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
/*
|
|
||||||
Package message is an encoder/decoder library for MQTT 3.1 and 3.1.1 messages. You can
|
|
||||||
find the MQTT specs at the following locations:
|
|
||||||
|
|
||||||
3.1.1 - http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/
|
|
||||||
3.1 - http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html
|
|
||||||
|
|
||||||
From the spec:
|
|
||||||
|
|
||||||
MQTT is a Client Server publish/subscribe messaging transport protocol. It is
|
|
||||||
light weight, open, simple, and designed so as to be easy to implement. These
|
|
||||||
characteristics make it ideal for use in many situations, including constrained
|
|
||||||
environments such as for communication in Machine to Machine (M2M) and Internet
|
|
||||||
of Things (IoT) contexts where a small code footprint is required and/or network
|
|
||||||
bandwidth is at a premium.
|
|
||||||
|
|
||||||
The MQTT protocol works by exchanging a series of MQTT messages in a defined way.
|
|
||||||
The protocol runs over TCP/IP, or over other network protocols that provide
|
|
||||||
ordered, lossless, bi-directional connections.
|
|
||||||
|
|
||||||
|
|
||||||
There are two main items to take note in this package. The first is
|
|
||||||
|
|
||||||
type MessageType byte
|
|
||||||
|
|
||||||
MessageType is the type representing the MQTT packet types. In the MQTT spec, MQTT
|
|
||||||
control packet type is represented as a 4-bit unsigned value. MessageType receives
|
|
||||||
several methods that returns string representations of the names and descriptions.
|
|
||||||
|
|
||||||
Also, one of the methods is New(). It returns a new Message object based on the mtype
|
|
||||||
parameter. For example:
|
|
||||||
|
|
||||||
m, err := CONNECT.New()
|
|
||||||
msg := m.(*ConnectMessage)
|
|
||||||
|
|
||||||
This would return a PublishMessage struct, but mapped to the Message interface. You can
|
|
||||||
then type assert it back to a *PublishMessage. Another way to create a new
|
|
||||||
PublishMessage is to call
|
|
||||||
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
|
|
||||||
Every message type has a New function that returns a new message. The list of available
|
|
||||||
message types are defined as constants below.
|
|
||||||
|
|
||||||
As you may have noticed, the second important item is the Message interface. It defines
|
|
||||||
several methods that are common to all messages, including Name(), Desc(), and Type().
|
|
||||||
Most importantly, it also defines the Encode() and Decode() methods.
|
|
||||||
|
|
||||||
Encode() (io.Reader, int, error)
|
|
||||||
Decode(io.Reader) (int, error)
|
|
||||||
|
|
||||||
Encode returns an io.Reader in which the encoded bytes can be read. The second return
|
|
||||||
value is the number of bytes encoded, so the caller knows how many bytes there will be.
|
|
||||||
If Encode returns an error, then the first two return values should be considered invalid.
|
|
||||||
Any changes to the message after Encode() is called will invalidate the io.Reader.
|
|
||||||
|
|
||||||
Decode reads from the io.Reader parameter until a full message is decoded, or when io.Reader
|
|
||||||
returns EOF or error. The first return value is the number of bytes read from io.Reader.
|
|
||||||
The second is error if Decode encounters any problems.
|
|
||||||
|
|
||||||
With these in mind, we can now do:
|
|
||||||
|
|
||||||
// Create a new CONNECT message
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
|
|
||||||
// Set the appropriate parameters
|
|
||||||
msg.SetWillQos(1)
|
|
||||||
msg.SetVersion(4)
|
|
||||||
msg.SetCleanSession(true)
|
|
||||||
msg.SetClientId([]byte("surgemq"))
|
|
||||||
msg.SetKeepAlive(10)
|
|
||||||
msg.SetWillTopic([]byte("will"))
|
|
||||||
msg.SetWillMessage([]byte("send me home"))
|
|
||||||
msg.SetUsername([]byte("surgemq"))
|
|
||||||
msg.SetPassword([]byte("verysecret"))
|
|
||||||
|
|
||||||
// Encode the message and get the io.Reader
|
|
||||||
r, n, err := msg.Encode()
|
|
||||||
if err == nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write n bytes into the connection
|
|
||||||
m, err := io.CopyN(conn, r, int64(n))
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
fmt.Printf("Sent %d bytes of %s message", m, msg.Name())
|
|
||||||
|
|
||||||
To receive a CONNECT message from a connection, we can do:
|
|
||||||
|
|
||||||
// Create a new CONNECT message
|
|
||||||
msg := NewConnectMessage()
|
|
||||||
|
|
||||||
// Decode the message by reading from conn
|
|
||||||
n, err := msg.Decode(conn)
|
|
||||||
|
|
||||||
If you don't know what type of message is coming down the pipe, you can do something like this:
|
|
||||||
|
|
||||||
// Create a buffered IO reader for the connection
|
|
||||||
br := bufio.NewReader(conn)
|
|
||||||
|
|
||||||
// Peek at the first byte, which contains the message type
|
|
||||||
b, err := br.Peek(1)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Extract the type from the first byte
|
|
||||||
t := MessageType(b[0] >> 4)
|
|
||||||
|
|
||||||
// Create a new message
|
|
||||||
msg, err := t.New()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode it from the bufio.Reader
|
|
||||||
n, err := msg.Decode(br)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
*/
|
|
||||||
package message
|
|
||||||
@@ -1,248 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
gPacketId uint64 = 0
|
|
||||||
)
|
|
||||||
|
|
||||||
// Fixed header
|
|
||||||
// - 1 byte for control packet type (bits 7-4) and flags (bits 3-0)
|
|
||||||
// - up to 4 byte for remaining length
|
|
||||||
type header struct {
|
|
||||||
// Header fields
|
|
||||||
//mtype MessageType
|
|
||||||
//flags byte
|
|
||||||
remlen int32
|
|
||||||
|
|
||||||
// mtypeflags is the first byte of the buffer, 4 bits for mtype, 4 bits for flags
|
|
||||||
mtypeflags []byte
|
|
||||||
|
|
||||||
// Some messages need packet ID, 2 byte uint16
|
|
||||||
packetId []byte
|
|
||||||
|
|
||||||
// Points to the decoding buffer
|
|
||||||
dbuf []byte
|
|
||||||
|
|
||||||
// Whether the message has changed since last decode
|
|
||||||
dirty bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns a string representation of the message.
|
|
||||||
func (this header) String() string {
|
|
||||||
return fmt.Sprintf("Type=%q, Flags=%08b, Remaining Length=%d", this.Type().Name(), this.Flags(), this.remlen)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name returns a string representation of the message type. Examples include
|
|
||||||
// "PUBLISH", "SUBSCRIBE", and others. This is statically defined for each of
|
|
||||||
// the message types and cannot be changed.
|
|
||||||
func (this *header) Name() string {
|
|
||||||
return this.Type().Name()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Desc returns a string description of the message type. For example, a
|
|
||||||
// CONNECT message would return "Client request to connect to Server." These
|
|
||||||
// descriptions are statically defined (copied from the MQTT spec) and cannot
|
|
||||||
// be changed.
|
|
||||||
func (this *header) Desc() string {
|
|
||||||
return this.Type().Desc()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Type returns the MessageType of the Message. The retured value should be one
|
|
||||||
// of the constants defined for MessageType.
|
|
||||||
func (this *header) Type() MessageType {
|
|
||||||
//return this.mtype
|
|
||||||
if len(this.mtypeflags) != 1 {
|
|
||||||
this.mtypeflags = make([]byte, 1)
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return MessageType(this.mtypeflags[0] >> 4)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetType sets the message type of this message. It also correctly sets the
|
|
||||||
// default flags for the message type. It returns an error if the type is invalid.
|
|
||||||
func (this *header) SetType(mtype MessageType) error {
|
|
||||||
if !mtype.Valid() {
|
|
||||||
return fmt.Errorf("header/SetType: Invalid control packet type %d", mtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notice we don't set the message to be dirty when we are not allocating a new
|
|
||||||
// buffer. In this case, it means the buffer is probably a sub-slice of another
|
|
||||||
// slice. If that's the case, then during encoding we would have copied the whole
|
|
||||||
// backing buffer anyway.
|
|
||||||
if len(this.mtypeflags) != 1 {
|
|
||||||
this.mtypeflags = make([]byte, 1)
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
this.mtypeflags[0] = byte(mtype)<<4 | (mtype.DefaultFlags() & 0xf)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flags returns the fixed header flags for this message.
|
|
||||||
func (this *header) Flags() byte {
|
|
||||||
//return this.flags
|
|
||||||
return this.mtypeflags[0] & 0x0f
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemainingLength returns the length of the non-fixed-header part of the message.
|
|
||||||
func (this *header) RemainingLength() int32 {
|
|
||||||
return this.remlen
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRemainingLength sets the length of the non-fixed-header part of the message.
|
|
||||||
// It returns error if the length is greater than 268435455, which is the max
|
|
||||||
// message length as defined by the MQTT spec.
|
|
||||||
func (this *header) SetRemainingLength(remlen int32) error {
|
|
||||||
if remlen > maxRemainingLength || remlen < 0 {
|
|
||||||
return fmt.Errorf("header/SetLength: Remaining length (%d) out of bound (max %d, min 0)", remlen, maxRemainingLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.remlen = remlen
|
|
||||||
this.dirty = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *header) Len() int {
|
|
||||||
return this.msglen()
|
|
||||||
}
|
|
||||||
|
|
||||||
// PacketId returns the ID of the packet.
|
|
||||||
func (this *header) PacketId() uint16 {
|
|
||||||
if len(this.packetId) == 2 {
|
|
||||||
return binary.BigEndian.Uint16(this.packetId)
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPacketId sets the ID of the packet.
|
|
||||||
func (this *header) SetPacketId(v uint16) {
|
|
||||||
// If setting to 0, nothing to do, move on
|
|
||||||
if v == 0 {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// If packetId buffer is not 2 bytes (uint16), then we allocate a new one and
|
|
||||||
// make dirty. Then we encode the packet ID into the buffer.
|
|
||||||
if len(this.packetId) != 2 {
|
|
||||||
this.packetId = make([]byte, 2)
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Notice we don't set the message to be dirty when we are not allocating a new
|
|
||||||
// buffer. In this case, it means the buffer is probably a sub-slice of another
|
|
||||||
// slice. If that's the case, then during encoding we would have copied the whole
|
|
||||||
// backing buffer anyway.
|
|
||||||
binary.BigEndian.PutUint16(this.packetId, v)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *header) encode(dst []byte) (int, error) {
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if len(dst) < ml {
|
|
||||||
return 0, fmt.Errorf("header/Encode: Insufficient buffer size. Expecting %d, got %d.", ml, len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
if this.remlen > maxRemainingLength || this.remlen < 0 {
|
|
||||||
return total, fmt.Errorf("header/Encode: Remaining length (%d) out of bound (max %d, min 0)", this.remlen, maxRemainingLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
if !this.Type().Valid() {
|
|
||||||
return total, fmt.Errorf("header/Encode: Invalid message type %d", this.Type())
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[total] = this.mtypeflags[0]
|
|
||||||
total += 1
|
|
||||||
|
|
||||||
n := binary.PutUvarint(dst[total:], uint64(this.remlen))
|
|
||||||
total += n
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode reads from the io.Reader parameter until a full message is decoded, or
|
|
||||||
// when io.Reader returns EOF or error. The first return value is the number of
|
|
||||||
// bytes read from io.Reader. The second is error if Decode encounters any problems.
|
|
||||||
func (this *header) decode(src []byte) (int, error) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
this.dbuf = src
|
|
||||||
|
|
||||||
mtype := this.Type()
|
|
||||||
//mtype := MessageType(0)
|
|
||||||
|
|
||||||
this.mtypeflags = src[total : total+1]
|
|
||||||
//mtype := MessageType(src[total] >> 4)
|
|
||||||
if !this.Type().Valid() {
|
|
||||||
return total, fmt.Errorf("header/Decode: Invalid message type %d.", mtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
if mtype != this.Type() {
|
|
||||||
return total, fmt.Errorf("header/Decode: Invalid message type %d. Expecting %d.", this.Type(), mtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
//this.flags = src[total] & 0x0f
|
|
||||||
if this.Type() != PUBLISH && this.Flags() != this.Type().DefaultFlags() {
|
|
||||||
return total, fmt.Errorf("header/Decode: Invalid message (%d) flags. Expecting %d, got %d", this.Type(), this.Type().DefaultFlags(), this.Flags())
|
|
||||||
}
|
|
||||||
|
|
||||||
if this.Type() == PUBLISH && !ValidQos((this.Flags()>>1)&0x3) {
|
|
||||||
return total, fmt.Errorf("header/Decode: Invalid QoS (%d) for PUBLISH message.", (this.Flags()>>1)&0x3)
|
|
||||||
}
|
|
||||||
|
|
||||||
total++
|
|
||||||
|
|
||||||
remlen, m := binary.Uvarint(src[total:])
|
|
||||||
total += m
|
|
||||||
this.remlen = int32(remlen)
|
|
||||||
|
|
||||||
if this.remlen > maxRemainingLength || remlen < 0 {
|
|
||||||
return total, fmt.Errorf("header/Decode: Remaining length (%d) out of bound (max %d, min 0)", this.remlen, maxRemainingLength)
|
|
||||||
}
|
|
||||||
|
|
||||||
if int(this.remlen) > len(src[total:]) {
|
|
||||||
return total, fmt.Errorf("header/Decode: Remaining length (%d) is greater than remaining buffer (%d)", this.remlen, len(src[total:]))
|
|
||||||
}
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *header) msglen() int {
|
|
||||||
// message type and flag byte
|
|
||||||
total := 1
|
|
||||||
|
|
||||||
if this.remlen <= 127 {
|
|
||||||
total += 1
|
|
||||||
} else if this.remlen <= 16383 {
|
|
||||||
total += 2
|
|
||||||
} else if this.remlen <= 2097151 {
|
|
||||||
total += 3
|
|
||||||
} else {
|
|
||||||
total += 4
|
|
||||||
}
|
|
||||||
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
@@ -1,182 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMessageHeaderFields(t *testing.T) {
|
|
||||||
header := &header{}
|
|
||||||
|
|
||||||
header.SetRemainingLength(33)
|
|
||||||
|
|
||||||
require.Equal(t, int32(33), header.RemainingLength())
|
|
||||||
|
|
||||||
err := header.SetRemainingLength(268435456)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
err = header.SetRemainingLength(-1)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
err = header.SetType(RESERVED)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
err = header.SetType(PUBREL)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, PUBREL, header.Type())
|
|
||||||
require.Equal(t, "PUBREL", header.Name())
|
|
||||||
require.Equal(t, 2, int(header.Flags()))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Not enough bytes
|
|
||||||
func TestMessageHeaderDecode(t *testing.T) {
|
|
||||||
buf := []byte{0x6f, 193, 2}
|
|
||||||
header := &header{}
|
|
||||||
|
|
||||||
_, err := header.decode(buf)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Remaining length too big
|
|
||||||
func TestMessageHeaderDecode2(t *testing.T) {
|
|
||||||
buf := []byte{0x62, 0xff, 0xff, 0xff, 0xff}
|
|
||||||
header := &header{}
|
|
||||||
|
|
||||||
_, err := header.decode(buf)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageHeaderDecode3(t *testing.T) {
|
|
||||||
buf := []byte{0x62, 0xff}
|
|
||||||
header := &header{}
|
|
||||||
|
|
||||||
_, err := header.decode(buf)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageHeaderDecode4(t *testing.T) {
|
|
||||||
buf := []byte{0x62, 0xff, 0xff, 0xff, 0x7f}
|
|
||||||
header := &header{
|
|
||||||
mtypeflags: []byte{6<<4 | 2},
|
|
||||||
//mtype: 6,
|
|
||||||
//flags: 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := header.decode(buf)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Equal(t, 5, n)
|
|
||||||
require.Equal(t, maxRemainingLength, header.RemainingLength())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageHeaderDecode5(t *testing.T) {
|
|
||||||
buf := []byte{0x62, 0xff, 0x7f}
|
|
||||||
header := &header{
|
|
||||||
mtypeflags: []byte{6<<4 | 2},
|
|
||||||
//mtype: 6,
|
|
||||||
//flags: 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err := header.decode(buf)
|
|
||||||
require.Error(t, err)
|
|
||||||
require.Equal(t, 3, n)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageHeaderEncode1(t *testing.T) {
|
|
||||||
header := &header{}
|
|
||||||
headerBytes := []byte{0x62, 193, 2}
|
|
||||||
|
|
||||||
err := header.SetType(PUBREL)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = header.SetRemainingLength(321)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
buf := make([]byte, 3)
|
|
||||||
n, err := header.encode(buf)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 3, n)
|
|
||||||
require.Equal(t, headerBytes, buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageHeaderEncode2(t *testing.T) {
|
|
||||||
header := &header{}
|
|
||||||
|
|
||||||
err := header.SetType(PUBREL)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
header.remlen = 268435456
|
|
||||||
|
|
||||||
buf := make([]byte, 5)
|
|
||||||
_, err = header.encode(buf)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageHeaderEncode3(t *testing.T) {
|
|
||||||
header := &header{}
|
|
||||||
headerBytes := []byte{0x62, 0xff, 0xff, 0xff, 0x7f}
|
|
||||||
|
|
||||||
err := header.SetType(PUBREL)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = header.SetRemainingLength(maxRemainingLength)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
buf := make([]byte, 5)
|
|
||||||
n, err := header.encode(buf)
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 5, n)
|
|
||||||
require.Equal(t, headerBytes, buf)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageHeaderEncode4(t *testing.T) {
|
|
||||||
header := &header{
|
|
||||||
mtypeflags: []byte{byte(RESERVED2) << 4},
|
|
||||||
//mtype: 6,
|
|
||||||
//flags: 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
buf := make([]byte, 5)
|
|
||||||
_, err := header.encode(buf)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
/*
|
|
||||||
// This test is to ensure that an empty message is at least 2 bytes long
|
|
||||||
func TestMessageHeaderEncode5(t *testing.T) {
|
|
||||||
msg := NewPingreqMessage()
|
|
||||||
|
|
||||||
dst, n, err := msg.encode()
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("Error encoding PINGREQ message: %v", err)
|
|
||||||
} else if n != 2 {
|
|
||||||
t.Errorf("Incorrect result. Expecting length of 2 bytes, got %d.", dst.(*bytes.Buffer).Len())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
@@ -1,410 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
maxLPString uint16 = 65535
|
|
||||||
maxFixedHeaderLength int = 5
|
|
||||||
maxRemainingLength int32 = 268435455 // bytes, or 256 MB
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// QoS 0: At most once delivery
|
|
||||||
// The message is delivered according to the capabilities of the underlying network.
|
|
||||||
// No response is sent by the receiver and no retry is performed by the sender. The
|
|
||||||
// message arrives at the receiver either once or not at all.
|
|
||||||
QosAtMostOnce byte = iota
|
|
||||||
|
|
||||||
// QoS 1: At least once delivery
|
|
||||||
// This quality of service ensures that the message arrives at the receiver at least once.
|
|
||||||
// A QoS 1 PUBLISH Packet has a Packet Identifier in its variable header and is acknowledged
|
|
||||||
// by a PUBACK Packet. Section 2.3.1 provides more information about Packet Identifiers.
|
|
||||||
QosAtLeastOnce
|
|
||||||
|
|
||||||
// QoS 2: Exactly once delivery
|
|
||||||
// This is the highest quality of service, for use when neither loss nor duplication of
|
|
||||||
// messages are acceptable. There is an increased overhead associated with this quality of
|
|
||||||
// service.
|
|
||||||
QosExactlyOnce
|
|
||||||
|
|
||||||
// QosFailure is a return value for a subscription if there's a problem while subscribing
|
|
||||||
// to a specific topic.
|
|
||||||
QosFailure = 0x80
|
|
||||||
)
|
|
||||||
|
|
||||||
// SupportedVersions is a map of the version number (0x3 or 0x4) to the version string,
|
|
||||||
// "MQIsdp" for 0x3, and "MQTT" for 0x4.
|
|
||||||
var SupportedVersions map[byte]string = map[byte]string{
|
|
||||||
0x3: "MQIsdp",
|
|
||||||
0x4: "MQTT",
|
|
||||||
}
|
|
||||||
|
|
||||||
// MessageType is the type representing the MQTT packet types. In the MQTT spec,
|
|
||||||
// MQTT control packet type is represented as a 4-bit unsigned value.
|
|
||||||
type MessageType byte
|
|
||||||
|
|
||||||
// Message is an interface defined for all MQTT message types.
|
|
||||||
type Message interface {
|
|
||||||
// Name returns a string representation of the message type. Examples include
|
|
||||||
// "PUBLISH", "SUBSCRIBE", and others. This is statically defined for each of
|
|
||||||
// the message types and cannot be changed.
|
|
||||||
Name() string
|
|
||||||
|
|
||||||
// Desc returns a string description of the message type. For example, a
|
|
||||||
// CONNECT message would return "Client request to connect to Server." These
|
|
||||||
// descriptions are statically defined (copied from the MQTT spec) and cannot
|
|
||||||
// be changed.
|
|
||||||
Desc() string
|
|
||||||
|
|
||||||
// Type returns the MessageType of the Message. The retured value should be one
|
|
||||||
// of the constants defined for MessageType.
|
|
||||||
Type() MessageType
|
|
||||||
|
|
||||||
// PacketId returns the packet ID of the Message. The retured value is 0 if
|
|
||||||
// there's no packet ID for this message type. Otherwise non-0.
|
|
||||||
PacketId() uint16
|
|
||||||
|
|
||||||
// Encode writes the message bytes into the byte array from the argument. It
|
|
||||||
// returns the number of bytes encoded and whether there's any errors along
|
|
||||||
// the way. If there's any errors, then the byte slice and count should be
|
|
||||||
// considered invalid.
|
|
||||||
Encode([]byte) (int, error)
|
|
||||||
|
|
||||||
// Decode reads the bytes in the byte slice from the argument. It returns the
|
|
||||||
// total number of bytes decoded, and whether there's any errors during the
|
|
||||||
// process. The byte slice MUST NOT be modified during the duration of this
|
|
||||||
// message being available since the byte slice is internally stored for
|
|
||||||
// references.
|
|
||||||
Decode([]byte) (int, error)
|
|
||||||
|
|
||||||
Len() int
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
// RESERVED is a reserved value and should be considered an invalid message type
|
|
||||||
RESERVED MessageType = iota
|
|
||||||
|
|
||||||
// CONNECT: Client to Server. Client request to connect to Server.
|
|
||||||
CONNECT
|
|
||||||
|
|
||||||
// CONNACK: Server to Client. Connect acknowledgement.
|
|
||||||
CONNACK
|
|
||||||
|
|
||||||
// PUBLISH: Client to Server, or Server to Client. Publish message.
|
|
||||||
PUBLISH
|
|
||||||
|
|
||||||
// PUBACK: Client to Server, or Server to Client. Publish acknowledgment for
|
|
||||||
// QoS 1 messages.
|
|
||||||
PUBACK
|
|
||||||
|
|
||||||
// PUBACK: Client to Server, or Server to Client. Publish received for QoS 2 messages.
|
|
||||||
// Assured delivery part 1.
|
|
||||||
PUBREC
|
|
||||||
|
|
||||||
// PUBREL: Client to Server, or Server to Client. Publish release for QoS 2 messages.
|
|
||||||
// Assured delivery part 1.
|
|
||||||
PUBREL
|
|
||||||
|
|
||||||
// PUBCOMP: Client to Server, or Server to Client. Publish complete for QoS 2 messages.
|
|
||||||
// Assured delivery part 3.
|
|
||||||
PUBCOMP
|
|
||||||
|
|
||||||
// SUBSCRIBE: Client to Server. Client subscribe request.
|
|
||||||
SUBSCRIBE
|
|
||||||
|
|
||||||
// SUBACK: Server to Client. Subscribe acknowledgement.
|
|
||||||
SUBACK
|
|
||||||
|
|
||||||
// UNSUBSCRIBE: Client to Server. Unsubscribe request.
|
|
||||||
UNSUBSCRIBE
|
|
||||||
|
|
||||||
// UNSUBACK: Server to Client. Unsubscribe acknowlegment.
|
|
||||||
UNSUBACK
|
|
||||||
|
|
||||||
// PINGREQ: Client to Server. PING request.
|
|
||||||
PINGREQ
|
|
||||||
|
|
||||||
// PINGRESP: Server to Client. PING response.
|
|
||||||
PINGRESP
|
|
||||||
|
|
||||||
// DISCONNECT: Client to Server. Client is disconnecting.
|
|
||||||
DISCONNECT
|
|
||||||
|
|
||||||
// RESERVED2 is a reserved value and should be considered an invalid message type.
|
|
||||||
RESERVED2
|
|
||||||
)
|
|
||||||
|
|
||||||
func (this MessageType) String() string {
|
|
||||||
return this.Name()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Name returns the name of the message type. It should correspond to one of the
|
|
||||||
// constant values defined for MessageType. It is statically defined and cannot
|
|
||||||
// be changed.
|
|
||||||
func (this MessageType) Name() string {
|
|
||||||
switch this {
|
|
||||||
case RESERVED:
|
|
||||||
return "RESERVED"
|
|
||||||
case CONNECT:
|
|
||||||
return "CONNECT"
|
|
||||||
case CONNACK:
|
|
||||||
return "CONNACK"
|
|
||||||
case PUBLISH:
|
|
||||||
return "PUBLISH"
|
|
||||||
case PUBACK:
|
|
||||||
return "PUBACK"
|
|
||||||
case PUBREC:
|
|
||||||
return "PUBREC"
|
|
||||||
case PUBREL:
|
|
||||||
return "PUBREL"
|
|
||||||
case PUBCOMP:
|
|
||||||
return "PUBCOMP"
|
|
||||||
case SUBSCRIBE:
|
|
||||||
return "SUBSCRIBE"
|
|
||||||
case SUBACK:
|
|
||||||
return "SUBACK"
|
|
||||||
case UNSUBSCRIBE:
|
|
||||||
return "UNSUBSCRIBE"
|
|
||||||
case UNSUBACK:
|
|
||||||
return "UNSUBACK"
|
|
||||||
case PINGREQ:
|
|
||||||
return "PINGREQ"
|
|
||||||
case PINGRESP:
|
|
||||||
return "PINGRESP"
|
|
||||||
case DISCONNECT:
|
|
||||||
return "DISCONNECT"
|
|
||||||
case RESERVED2:
|
|
||||||
return "RESERVED2"
|
|
||||||
}
|
|
||||||
|
|
||||||
return "UNKNOWN"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Desc returns the description of the message type. It is statically defined (copied
|
|
||||||
// from MQTT spec) and cannot be changed.
|
|
||||||
func (this MessageType) Desc() string {
|
|
||||||
switch this {
|
|
||||||
case RESERVED:
|
|
||||||
return "Reserved"
|
|
||||||
case CONNECT:
|
|
||||||
return "Client request to connect to Server"
|
|
||||||
case CONNACK:
|
|
||||||
return "Connect acknowledgement"
|
|
||||||
case PUBLISH:
|
|
||||||
return "Publish message"
|
|
||||||
case PUBACK:
|
|
||||||
return "Publish acknowledgement"
|
|
||||||
case PUBREC:
|
|
||||||
return "Publish received (assured delivery part 1)"
|
|
||||||
case PUBREL:
|
|
||||||
return "Publish release (assured delivery part 2)"
|
|
||||||
case PUBCOMP:
|
|
||||||
return "Publish complete (assured delivery part 3)"
|
|
||||||
case SUBSCRIBE:
|
|
||||||
return "Client subscribe request"
|
|
||||||
case SUBACK:
|
|
||||||
return "Subscribe acknowledgement"
|
|
||||||
case UNSUBSCRIBE:
|
|
||||||
return "Unsubscribe request"
|
|
||||||
case UNSUBACK:
|
|
||||||
return "Unsubscribe acknowledgement"
|
|
||||||
case PINGREQ:
|
|
||||||
return "PING request"
|
|
||||||
case PINGRESP:
|
|
||||||
return "PING response"
|
|
||||||
case DISCONNECT:
|
|
||||||
return "Client is disconnecting"
|
|
||||||
case RESERVED2:
|
|
||||||
return "Reserved"
|
|
||||||
}
|
|
||||||
|
|
||||||
return "UNKNOWN"
|
|
||||||
}
|
|
||||||
|
|
||||||
// DefaultFlags returns the default flag values for the message type, as defined by
|
|
||||||
// the MQTT spec.
|
|
||||||
func (this MessageType) DefaultFlags() byte {
|
|
||||||
switch this {
|
|
||||||
case RESERVED:
|
|
||||||
return 0
|
|
||||||
case CONNECT:
|
|
||||||
return 0
|
|
||||||
case CONNACK:
|
|
||||||
return 0
|
|
||||||
case PUBLISH:
|
|
||||||
return 0
|
|
||||||
case PUBACK:
|
|
||||||
return 0
|
|
||||||
case PUBREC:
|
|
||||||
return 0
|
|
||||||
case PUBREL:
|
|
||||||
return 2
|
|
||||||
case PUBCOMP:
|
|
||||||
return 0
|
|
||||||
case SUBSCRIBE:
|
|
||||||
return 2
|
|
||||||
case SUBACK:
|
|
||||||
return 0
|
|
||||||
case UNSUBSCRIBE:
|
|
||||||
return 2
|
|
||||||
case UNSUBACK:
|
|
||||||
return 0
|
|
||||||
case PINGREQ:
|
|
||||||
return 0
|
|
||||||
case PINGRESP:
|
|
||||||
return 0
|
|
||||||
case DISCONNECT:
|
|
||||||
return 0
|
|
||||||
case RESERVED2:
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
// New creates a new message based on the message type. It is a shortcut to call
|
|
||||||
// one of the New*Message functions. If an error is returned then the message type
|
|
||||||
// is invalid.
|
|
||||||
func (this MessageType) New() (Message, error) {
|
|
||||||
switch this {
|
|
||||||
case CONNECT:
|
|
||||||
return NewConnectMessage(), nil
|
|
||||||
case CONNACK:
|
|
||||||
return NewConnackMessage(), nil
|
|
||||||
case PUBLISH:
|
|
||||||
return NewPublishMessage(), nil
|
|
||||||
case PUBACK:
|
|
||||||
return NewPubackMessage(), nil
|
|
||||||
case PUBREC:
|
|
||||||
return NewPubrecMessage(), nil
|
|
||||||
case PUBREL:
|
|
||||||
return NewPubrelMessage(), nil
|
|
||||||
case PUBCOMP:
|
|
||||||
return NewPubcompMessage(), nil
|
|
||||||
case SUBSCRIBE:
|
|
||||||
return NewSubscribeMessage(), nil
|
|
||||||
case SUBACK:
|
|
||||||
return NewSubackMessage(), nil
|
|
||||||
case UNSUBSCRIBE:
|
|
||||||
return NewUnsubscribeMessage(), nil
|
|
||||||
case UNSUBACK:
|
|
||||||
return NewUnsubackMessage(), nil
|
|
||||||
case PINGREQ:
|
|
||||||
return NewPingreqMessage(), nil
|
|
||||||
case PINGRESP:
|
|
||||||
return NewPingrespMessage(), nil
|
|
||||||
case DISCONNECT:
|
|
||||||
return NewDisconnectMessage(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil, fmt.Errorf("msgtype/NewMessage: Invalid message type %d", this)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Valid returns a boolean indicating whether the message type is valid or not.
|
|
||||||
func (this MessageType) Valid() bool {
|
|
||||||
return this > RESERVED && this < RESERVED2
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidTopic checks the topic, which is a slice of bytes, to see if it's valid. Topic is
|
|
||||||
// considered valid if it's longer than 0 bytes, and doesn't contain any wildcard characters
|
|
||||||
// such as + and #.
|
|
||||||
func ValidTopic(topic []byte) bool {
|
|
||||||
return len(topic) > 0 && bytes.IndexByte(topic, '#') == -1 && bytes.IndexByte(topic, '+') == -1
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidQos checks the QoS value to see if it's valid. Valid QoS are QosAtMostOnce,
|
|
||||||
// QosAtLeastonce, and QosExactlyOnce.
|
|
||||||
func ValidQos(qos byte) bool {
|
|
||||||
return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidVersion checks to see if the version is valid. Current supported versions include 0x3 and 0x4.
|
|
||||||
func ValidVersion(v byte) bool {
|
|
||||||
_, ok := SupportedVersions[v]
|
|
||||||
return ok
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidConnackError checks to see if the error is a Connack Error or not
|
|
||||||
func ValidConnackError(err error) bool {
|
|
||||||
return err == ErrInvalidProtocolVersion || err == ErrIdentifierRejected ||
|
|
||||||
err == ErrServerUnavailable || err == ErrBadUsernameOrPassword || err == ErrNotAuthorized
|
|
||||||
}
|
|
||||||
|
|
||||||
func ValidConnackErrorEx(err error) (bool, ConnackCode) {
|
|
||||||
if err == ErrInvalidProtocolVersion {
|
|
||||||
return true, ErrInvalidProtocolVersion
|
|
||||||
}
|
|
||||||
if err == ErrIdentifierRejected {
|
|
||||||
return true, ErrIdentifierRejected
|
|
||||||
}
|
|
||||||
if err == ErrServerUnavailable {
|
|
||||||
return true, ErrServerUnavailable
|
|
||||||
}
|
|
||||||
if err == ErrBadUsernameOrPassword {
|
|
||||||
return true, ErrBadUsernameOrPassword
|
|
||||||
}
|
|
||||||
if err == ErrNotAuthorized {
|
|
||||||
return true, ErrNotAuthorized
|
|
||||||
}
|
|
||||||
return false, ConnectionAccepted
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read length prefixed bytes
|
|
||||||
func readLPBytes(buf []byte) ([]byte, int, error) {
|
|
||||||
if len(buf) < 2 {
|
|
||||||
return nil, 0, fmt.Errorf("utils/readLPBytes: Insufficient buffer size. Expecting %d, got %d.", 2, len(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
n, total := 0, 0
|
|
||||||
|
|
||||||
n = int(binary.BigEndian.Uint16(buf))
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
if len(buf) < n {
|
|
||||||
return nil, total, fmt.Errorf("utils/readLPBytes: Insufficient buffer size. Expecting %d, got %d.", n, len(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
total += n
|
|
||||||
|
|
||||||
return buf[2:total], total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write length prefixed bytes
|
|
||||||
func writeLPBytes(buf []byte, b []byte) (int, error) {
|
|
||||||
total, n := 0, len(b)
|
|
||||||
|
|
||||||
if n > int(maxLPString) {
|
|
||||||
return 0, fmt.Errorf("utils/writeLPBytes: Length (%d) greater than %d bytes.", n, maxLPString)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(buf) < 2+n {
|
|
||||||
return 0, fmt.Errorf("utils/writeLPBytes: Insufficient buffer size. Expecting %d, got %d.", 2+n, len(buf))
|
|
||||||
}
|
|
||||||
|
|
||||||
binary.BigEndian.PutUint16(buf, uint16(n))
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
copy(buf[total:], b)
|
|
||||||
total += n
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
@@ -1,178 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
lpstrings []string = []string{
|
|
||||||
"this is a test",
|
|
||||||
"hope it succeeds",
|
|
||||||
"but just in case",
|
|
||||||
"send me your millions",
|
|
||||||
"",
|
|
||||||
}
|
|
||||||
|
|
||||||
lpstringBytes []byte = []byte{
|
|
||||||
0x0, 0xe, 't', 'h', 'i', 's', ' ', 'i', 's', ' ', 'a', ' ', 't', 'e', 's', 't',
|
|
||||||
0x0, 0x10, 'h', 'o', 'p', 'e', ' ', 'i', 't', ' ', 's', 'u', 'c', 'c', 'e', 'e', 'd', 's',
|
|
||||||
0x0, 0x10, 'b', 'u', 't', ' ', 'j', 'u', 's', 't', ' ', 'i', 'n', ' ', 'c', 'a', 's', 'e',
|
|
||||||
0x0, 0x15, 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'y', 'o', 'u', 'r', ' ', 'm', 'i', 'l', 'l', 'i', 'o', 'n', 's',
|
|
||||||
0x0, 0x0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msgBytes []byte = []byte{
|
|
||||||
byte(CONNECT << 4),
|
|
||||||
60,
|
|
||||||
0, // Length MSB (0)
|
|
||||||
4, // Length LSB (4)
|
|
||||||
'M', 'Q', 'T', 'T',
|
|
||||||
4, // Protocol level 4
|
|
||||||
206, // connect flags 11001110, will QoS = 01
|
|
||||||
0, // Keep Alive MSB (0)
|
|
||||||
10, // Keep Alive LSB (10)
|
|
||||||
0, // Client ID MSB (0)
|
|
||||||
7, // Client ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Will Topic MSB (0)
|
|
||||||
4, // Will Topic LSB (4)
|
|
||||||
'w', 'i', 'l', 'l',
|
|
||||||
0, // Will Message MSB (0)
|
|
||||||
12, // Will Message LSB (12)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
0, // Username ID MSB (0)
|
|
||||||
7, // Username ID LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // Password ID MSB (0)
|
|
||||||
10, // Password ID LSB (10)
|
|
||||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestReadLPBytes(t *testing.T) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
for _, str := range lpstrings {
|
|
||||||
b, n, err := readLPBytes(lpstringBytes[total:])
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, str, string(b))
|
|
||||||
require.Equal(t, len(str)+2, n)
|
|
||||||
|
|
||||||
total += n
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWriteLPBytes(t *testing.T) {
|
|
||||||
total := 0
|
|
||||||
buf := make([]byte, 1000)
|
|
||||||
|
|
||||||
for _, str := range lpstrings {
|
|
||||||
n, err := writeLPBytes(buf[total:], []byte(str))
|
|
||||||
|
|
||||||
require.NoError(t, err)
|
|
||||||
require.Equal(t, 2+len(str), n)
|
|
||||||
|
|
||||||
total += n
|
|
||||||
}
|
|
||||||
|
|
||||||
require.Equal(t, lpstringBytes, buf[:total])
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestMessageTypes(t *testing.T) {
|
|
||||||
if CONNECT != 1 ||
|
|
||||||
CONNACK != 2 ||
|
|
||||||
PUBLISH != 3 ||
|
|
||||||
PUBACK != 4 ||
|
|
||||||
PUBREC != 5 ||
|
|
||||||
PUBREL != 6 ||
|
|
||||||
PUBCOMP != 7 ||
|
|
||||||
SUBSCRIBE != 8 ||
|
|
||||||
SUBACK != 9 ||
|
|
||||||
UNSUBSCRIBE != 10 ||
|
|
||||||
UNSUBACK != 11 ||
|
|
||||||
PINGREQ != 12 ||
|
|
||||||
PINGRESP != 13 ||
|
|
||||||
DISCONNECT != 14 {
|
|
||||||
|
|
||||||
t.Errorf("Message types have invalid code")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQosCodes(t *testing.T) {
|
|
||||||
if QosAtMostOnce != 0 || QosAtLeastOnce != 1 || QosExactlyOnce != 2 {
|
|
||||||
t.Errorf("QOS codes invalid")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConnackReturnCodes(t *testing.T) {
|
|
||||||
require.Equal(t, ErrInvalidProtocolVersion.Error(), ConnackCode(1).Error(), "Incorrect ConnackCode error value.")
|
|
||||||
|
|
||||||
require.Equal(t, ErrIdentifierRejected.Error(), ConnackCode(2).Error(), "Incorrect ConnackCode error value.")
|
|
||||||
|
|
||||||
require.Equal(t, ErrServerUnavailable.Error(), ConnackCode(3).Error(), "Incorrect ConnackCode error value.")
|
|
||||||
|
|
||||||
require.Equal(t, ErrBadUsernameOrPassword.Error(), ConnackCode(4).Error(), "Incorrect ConnackCode error value.")
|
|
||||||
|
|
||||||
require.Equal(t, ErrNotAuthorized.Error(), ConnackCode(5).Error(), "Incorrect ConnackCode error value.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFixedHeaderFlags(t *testing.T) {
|
|
||||||
type detail struct {
|
|
||||||
name string
|
|
||||||
flags byte
|
|
||||||
}
|
|
||||||
|
|
||||||
details := map[MessageType]detail{
|
|
||||||
RESERVED: detail{"RESERVED", 0},
|
|
||||||
CONNECT: detail{"CONNECT", 0},
|
|
||||||
CONNACK: detail{"CONNACK", 0},
|
|
||||||
PUBLISH: detail{"PUBLISH", 0},
|
|
||||||
PUBACK: detail{"PUBACK", 0},
|
|
||||||
PUBREC: detail{"PUBREC", 0},
|
|
||||||
PUBREL: detail{"PUBREL", 2},
|
|
||||||
PUBCOMP: detail{"PUBCOMP", 0},
|
|
||||||
SUBSCRIBE: detail{"SUBSCRIBE", 2},
|
|
||||||
SUBACK: detail{"SUBACK", 0},
|
|
||||||
UNSUBSCRIBE: detail{"UNSUBSCRIBE", 2},
|
|
||||||
UNSUBACK: detail{"UNSUBACK", 0},
|
|
||||||
PINGREQ: detail{"PINGREQ", 0},
|
|
||||||
PINGRESP: detail{"PINGRESP", 0},
|
|
||||||
DISCONNECT: detail{"DISCONNECT", 0},
|
|
||||||
RESERVED2: detail{"RESERVED2", 0},
|
|
||||||
}
|
|
||||||
|
|
||||||
for m, d := range details {
|
|
||||||
if m.Name() != d.name {
|
|
||||||
t.Errorf("Name mismatch. Expecting %s, got %s.", d.name, m.Name())
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.DefaultFlags() != d.flags {
|
|
||||||
t.Errorf("Flag mismatch for %s. Expecting %d, got %d.", m.Name(), d.flags, m.DefaultFlags())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSupportedVersions(t *testing.T) {
|
|
||||||
for k, v := range SupportedVersions {
|
|
||||||
if k == 0x03 && v != "MQIsdp" {
|
|
||||||
t.Errorf("Protocol version and name mismatch. Expect %s, got %s.", "MQIsdp", v)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPingreqMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PINGREQ << 4),
|
|
||||||
0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPingreqMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, PINGREQ, msg.Type(), "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPingreqMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PINGREQ << 4),
|
|
||||||
0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPingreqMessage()
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPingrespMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PINGRESP << 4),
|
|
||||||
0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPingrespMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, PINGRESP, msg.Type(), "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPingrespMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PINGRESP << 4),
|
|
||||||
0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPingrespMessage()
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestPingreqDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PINGREQ << 4),
|
|
||||||
0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPingreqMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestPingrespDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PINGRESP << 4),
|
|
||||||
0,
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPingrespMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
// The PINGREQ Packet is sent from a Client to the Server. It can be used to:
|
|
||||||
// 1. Indicate to the Server that the Client is alive in the absence of any other
|
|
||||||
// Control Packets being sent from the Client to the Server.
|
|
||||||
// 2. Request that the Server responds to confirm that it is alive.
|
|
||||||
// 3. Exercise the network to indicate that the Network Connection is active.
|
|
||||||
type PingreqMessage struct {
|
|
||||||
DisconnectMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*PingreqMessage)(nil)
|
|
||||||
|
|
||||||
// NewPingreqMessage creates a new PINGREQ message.
|
|
||||||
func NewPingreqMessage() *PingreqMessage {
|
|
||||||
msg := &PingreqMessage{}
|
|
||||||
msg.SetType(PINGREQ)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
// A PINGRESP Packet is sent by the Server to the Client in response to a PINGREQ
|
|
||||||
// Packet. It indicates that the Server is alive.
|
|
||||||
type PingrespMessage struct {
|
|
||||||
DisconnectMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*PingrespMessage)(nil)
|
|
||||||
|
|
||||||
// NewPingrespMessage creates a new PINGRESP message.
|
|
||||||
func NewPingrespMessage() *PingrespMessage {
|
|
||||||
msg := &PingrespMessage{}
|
|
||||||
msg.SetType(PINGRESP)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
@@ -1,109 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// A PUBACK Packet is the response to a PUBLISH Packet with QoS level 1.
|
|
||||||
type PubackMessage struct {
|
|
||||||
header
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*PubackMessage)(nil)
|
|
||||||
|
|
||||||
// NewPubackMessage creates a new PUBACK message.
|
|
||||||
func NewPubackMessage() *PubackMessage {
|
|
||||||
msg := &PubackMessage{}
|
|
||||||
msg.SetType(PUBACK)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this PubackMessage) String() string {
|
|
||||||
return fmt.Sprintf("%s, Packet ID=%d", this.header, this.packetId)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *PubackMessage) Len() int {
|
|
||||||
if !this.dirty {
|
|
||||||
return len(this.dbuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.header.msglen() + ml
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *PubackMessage) Decode(src []byte) (int, error) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.decode(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
|
||||||
this.packetId = src[total : total+2]
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
this.dirty = false
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *PubackMessage) Encode(dst []byte) (int, error) {
|
|
||||||
if !this.dirty {
|
|
||||||
if len(dst) < len(this.dbuf) {
|
|
||||||
return 0, fmt.Errorf("puback/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(dst, this.dbuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
hl := this.header.msglen()
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if len(dst) < hl+ml {
|
|
||||||
return 0, fmt.Errorf("puback/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.encode(dst[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if copy(dst[total:total+2], this.packetId) != 2 {
|
|
||||||
dst[total], dst[total+1] = 0, 0
|
|
||||||
}
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *PubackMessage) msglen() int {
|
|
||||||
// packet ID
|
|
||||||
return 2
|
|
||||||
}
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPubackMessageFields(t *testing.T) {
|
|
||||||
msg := NewPubackMessage()
|
|
||||||
|
|
||||||
msg.SetPacketId(100)
|
|
||||||
|
|
||||||
require.Equal(t, 100, int(msg.PacketId()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPubackMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBACK << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubackMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, PUBACK, msg.Type(), "Error decoding message.")
|
|
||||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test insufficient bytes
|
|
||||||
func TestPubackMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBACK << 4),
|
|
||||||
2,
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubackMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPubackMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBACK << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubackMessage()
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestPubackDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBACK << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubackMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
// The PUBCOMP Packet is the response to a PUBREL Packet. It is the fourth and
|
|
||||||
// final packet of the QoS 2 protocol exchange.
|
|
||||||
type PubcompMessage struct {
|
|
||||||
PubackMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*PubcompMessage)(nil)
|
|
||||||
|
|
||||||
// NewPubcompMessage creates a new PUBCOMP message.
|
|
||||||
func NewPubcompMessage() *PubcompMessage {
|
|
||||||
msg := &PubcompMessage{}
|
|
||||||
msg.SetType(PUBCOMP)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPubcompMessageFields(t *testing.T) {
|
|
||||||
msg := NewPubcompMessage()
|
|
||||||
|
|
||||||
msg.SetPacketId(100)
|
|
||||||
|
|
||||||
require.Equal(t, 100, int(msg.PacketId()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPubcompMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBCOMP << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubcompMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, PUBCOMP, msg.Type(), "Error decoding message.")
|
|
||||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test insufficient bytes
|
|
||||||
func TestPubcompMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBCOMP << 4),
|
|
||||||
2,
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubcompMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPubcompMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBCOMP << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubcompMessage()
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestPubcompDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBCOMP << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubcompMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,250 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A PUBLISH Control Packet is sent from a Client to a Server or from Server to a Client
|
|
||||||
// to transport an Application Message.
|
|
||||||
type PublishMessage struct {
|
|
||||||
header
|
|
||||||
|
|
||||||
topic []byte
|
|
||||||
payload []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*PublishMessage)(nil)
|
|
||||||
|
|
||||||
// NewPublishMessage creates a new PUBLISH message.
|
|
||||||
func NewPublishMessage() *PublishMessage {
|
|
||||||
msg := &PublishMessage{}
|
|
||||||
msg.SetType(PUBLISH)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this PublishMessage) String() string {
|
|
||||||
return fmt.Sprintf("%s, Topic=%q, Packet ID=%d, QoS=%d, Retained=%t, Dup=%t, Payload=%v",
|
|
||||||
this.header, this.topic, this.packetId, this.QoS(), this.Retain(), this.Dup(), this.payload)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Dup returns the value specifying the duplicate delivery of a PUBLISH Control Packet.
|
|
||||||
// If the DUP flag is set to 0, it indicates that this is the first occasion that the
|
|
||||||
// Client or Server has attempted to send this MQTT PUBLISH Packet. If the DUP flag is
|
|
||||||
// set to 1, it indicates that this might be re-delivery of an earlier attempt to send
|
|
||||||
// the Packet.
|
|
||||||
func (this *PublishMessage) Dup() bool {
|
|
||||||
return ((this.Flags() >> 3) & 0x1) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetDup sets the value specifying the duplicate delivery of a PUBLISH Control Packet.
|
|
||||||
func (this *PublishMessage) SetDup(v bool) {
|
|
||||||
if v {
|
|
||||||
this.mtypeflags[0] |= 0x8 // 00001000
|
|
||||||
} else {
|
|
||||||
this.mtypeflags[0] &= 247 // 11110111
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retain returns the value of the RETAIN flag. This flag is only used on the PUBLISH
|
|
||||||
// Packet. If the RETAIN flag is set to 1, in a PUBLISH Packet sent by a Client to a
|
|
||||||
// Server, the Server MUST store the Application Message and its QoS, so that it can be
|
|
||||||
// delivered to future subscribers whose subscriptions match its topic name.
|
|
||||||
func (this *PublishMessage) Retain() bool {
|
|
||||||
return (this.Flags() & 0x1) == 1
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetRetain sets the value of the RETAIN flag.
|
|
||||||
func (this *PublishMessage) SetRetain(v bool) {
|
|
||||||
if v {
|
|
||||||
this.mtypeflags[0] |= 0x1 // 00000001
|
|
||||||
} else {
|
|
||||||
this.mtypeflags[0] &= 254 // 11111110
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// QoS returns the field that indicates the level of assurance for delivery of an
|
|
||||||
// Application Message. The values are QosAtMostOnce, QosAtLeastOnce and QosExactlyOnce.
|
|
||||||
func (this *PublishMessage) QoS() byte {
|
|
||||||
return (this.Flags() >> 1) & 0x3
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetQoS sets the field that indicates the level of assurance for delivery of an
|
|
||||||
// Application Message. The values are QosAtMostOnce, QosAtLeastOnce and QosExactlyOnce.
|
|
||||||
// An error is returned if the value is not one of these.
|
|
||||||
func (this *PublishMessage) SetQoS(v byte) error {
|
|
||||||
if v != 0x0 && v != 0x1 && v != 0x2 {
|
|
||||||
return fmt.Errorf("publish/SetQoS: Invalid QoS %d.", v)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.mtypeflags[0] = (this.mtypeflags[0] & 249) | (v << 1) // 249 = 11111001
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Topic returns the the topic name that identifies the information channel to which
|
|
||||||
// payload data is published.
|
|
||||||
func (this *PublishMessage) Topic() []byte {
|
|
||||||
return this.topic
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTopic sets the the topic name that identifies the information channel to which
|
|
||||||
// payload data is published. An error is returned if ValidTopic() is falbase.
|
|
||||||
func (this *PublishMessage) SetTopic(v []byte) error {
|
|
||||||
if !ValidTopic(v) {
|
|
||||||
return fmt.Errorf("publish/SetTopic: Invalid topic name (%s). Must not be empty or contain wildcard characters", string(v))
|
|
||||||
}
|
|
||||||
|
|
||||||
this.topic = v
|
|
||||||
this.dirty = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Payload returns the application message that's part of the PUBLISH message.
|
|
||||||
func (this *PublishMessage) Payload() []byte {
|
|
||||||
return this.payload
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetPayload sets the application message that's part of the PUBLISH message.
|
|
||||||
func (this *PublishMessage) SetPayload(v []byte) {
|
|
||||||
this.payload = v
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *PublishMessage) Len() int {
|
|
||||||
if !this.dirty {
|
|
||||||
return len(this.dbuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.header.msglen() + ml
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *PublishMessage) Decode(src []byte) (int, error) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
hn, err := this.header.decode(src[total:])
|
|
||||||
total += hn
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n := 0
|
|
||||||
|
|
||||||
this.topic, n, err = readLPBytes(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !ValidTopic(this.topic) {
|
|
||||||
return total, fmt.Errorf("publish/Decode: Invalid topic name (%s). Must not be empty or contain wildcard characters", string(this.topic))
|
|
||||||
}
|
|
||||||
|
|
||||||
// The packet identifier field is only present in the PUBLISH packets where the
|
|
||||||
// QoS level is 1 or 2
|
|
||||||
if this.QoS() != 0 {
|
|
||||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
|
||||||
this.packetId = src[total : total+2]
|
|
||||||
total += 2
|
|
||||||
}
|
|
||||||
|
|
||||||
l := int(this.remlen) - (total - hn)
|
|
||||||
this.payload = src[total : total+l]
|
|
||||||
total += len(this.payload)
|
|
||||||
|
|
||||||
this.dirty = false
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *PublishMessage) Encode(dst []byte) (int, error) {
|
|
||||||
if !this.dirty {
|
|
||||||
if len(dst) < len(this.dbuf) {
|
|
||||||
return 0, fmt.Errorf("publish/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(dst, this.dbuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(this.topic) == 0 {
|
|
||||||
return 0, fmt.Errorf("publish/Encode: Topic name is empty.")
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(this.payload) == 0 {
|
|
||||||
return 0, fmt.Errorf("publish/Encode: Payload is empty.")
|
|
||||||
}
|
|
||||||
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
hl := this.header.msglen()
|
|
||||||
|
|
||||||
if len(dst) < hl+ml {
|
|
||||||
return 0, fmt.Errorf("publish/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.encode(dst[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
n, err = writeLPBytes(dst[total:], this.topic)
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// The packet identifier field is only present in the PUBLISH packets where the QoS level is 1 or 2
|
|
||||||
if this.QoS() != 0 {
|
|
||||||
if this.PacketId() == 0 {
|
|
||||||
this.SetPacketId(uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff))
|
|
||||||
//this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff)
|
|
||||||
}
|
|
||||||
|
|
||||||
n = copy(dst[total:], this.packetId)
|
|
||||||
//binary.BigEndian.PutUint16(dst[total:], this.packetId)
|
|
||||||
total += n
|
|
||||||
}
|
|
||||||
|
|
||||||
copy(dst[total:], this.payload)
|
|
||||||
total += len(this.payload)
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *PublishMessage) msglen() int {
|
|
||||||
total := 2 + len(this.topic) + len(this.payload)
|
|
||||||
if this.QoS() != 0 {
|
|
||||||
total += 2
|
|
||||||
}
|
|
||||||
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
@@ -1,281 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPublishMessageHeaderFields(t *testing.T) {
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
msg.mtypeflags[0] |= 11
|
|
||||||
|
|
||||||
require.True(t, msg.Dup(), "Incorrect DUP flag.")
|
|
||||||
require.True(t, msg.Retain(), "Incorrect RETAIN flag.")
|
|
||||||
require.Equal(t, 1, int(msg.QoS()), "Incorrect QoS.")
|
|
||||||
|
|
||||||
msg.SetDup(false)
|
|
||||||
|
|
||||||
require.False(t, msg.Dup(), "Incorrect DUP flag.")
|
|
||||||
|
|
||||||
msg.SetRetain(false)
|
|
||||||
|
|
||||||
require.False(t, msg.Retain(), "Incorrect RETAIN flag.")
|
|
||||||
|
|
||||||
err := msg.SetQoS(2)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error setting QoS.")
|
|
||||||
require.Equal(t, 2, int(msg.QoS()), "Incorrect QoS.")
|
|
||||||
|
|
||||||
err = msg.SetQoS(3)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
err = msg.SetQoS(0)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error setting QoS.")
|
|
||||||
require.Equal(t, 0, int(msg.QoS()), "Incorrect QoS.")
|
|
||||||
|
|
||||||
msg.SetDup(true)
|
|
||||||
|
|
||||||
require.True(t, msg.Dup(), "Incorrect DUP flag.")
|
|
||||||
|
|
||||||
msg.SetRetain(true)
|
|
||||||
|
|
||||||
require.True(t, msg.Retain(), "Incorrect RETAIN flag.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPublishMessageFields(t *testing.T) {
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
|
|
||||||
msg.SetTopic([]byte("coolstuff"))
|
|
||||||
|
|
||||||
require.Equal(t, "coolstuff", string(msg.Topic()), "Error setting message topic.")
|
|
||||||
|
|
||||||
err := msg.SetTopic([]byte("coolstuff/#"))
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
|
|
||||||
msg.SetPacketId(100)
|
|
||||||
|
|
||||||
require.Equal(t, 100, int(msg.PacketId()), "Error setting acket ID.")
|
|
||||||
|
|
||||||
msg.SetPayload([]byte("this is a payload to be sent"))
|
|
||||||
|
|
||||||
require.Equal(t, []byte("this is a payload to be sent"), msg.Payload(), "Error setting payload.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPublishMessageDecode1(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBLISH<<4) | 2,
|
|
||||||
23,
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
|
||||||
require.Equal(t, "surgemq", string(msg.Topic()), "Error deocding topic name.")
|
|
||||||
require.Equal(t, []byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'}, msg.Payload(), "Error deocding payload.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test insufficient bytes
|
|
||||||
func TestPublishMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBLISH<<4) | 2,
|
|
||||||
26,
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// test qos = 0 and no client id
|
|
||||||
func TestPublishMessageDecode3(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBLISH << 4),
|
|
||||||
21,
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPublishMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBLISH<<4) | 2,
|
|
||||||
23,
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
msg.SetTopic([]byte("surgemq"))
|
|
||||||
msg.SetQoS(1)
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'})
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test empty topic name
|
|
||||||
func TestPublishMessageEncode2(t *testing.T) {
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
msg.SetTopic([]byte(""))
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'})
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
_, err := msg.Encode(dst)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// test encoding qos = 0 and no packet id
|
|
||||||
func TestPublishMessageEncode3(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBLISH << 4),
|
|
||||||
21,
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
msg.SetTopic([]byte("surgemq"))
|
|
||||||
msg.SetQoS(0)
|
|
||||||
msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'})
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test large message
|
|
||||||
func TestPublishMessageEncode4(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBLISH << 4),
|
|
||||||
137,
|
|
||||||
8,
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
}
|
|
||||||
|
|
||||||
payload := make([]byte, 1024)
|
|
||||||
msgBytes = append(msgBytes, payload...)
|
|
||||||
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
msg.SetTopic([]byte("surgemq"))
|
|
||||||
msg.SetQoS(0)
|
|
||||||
msg.SetPayload(payload)
|
|
||||||
|
|
||||||
require.Equal(t, len(msgBytes), msg.Len())
|
|
||||||
|
|
||||||
dst := make([]byte, 1100)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test from github issue #2, @mrdg
|
|
||||||
func TestPublishDecodeEncodeEquiv2(t *testing.T) {
|
|
||||||
msgBytes := []byte{50, 18, 0, 9, 103, 114, 101, 101, 116, 105, 110, 103, 115, 0, 1, 72, 101, 108, 108, 111}
|
|
||||||
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestPublishDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBLISH<<4) | 2,
|
|
||||||
23,
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPublishMessage()
|
|
||||||
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
type PubrecMessage struct {
|
|
||||||
PubackMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
// A PUBREC Packet is the response to a PUBLISH Packet with QoS 2. It is the second
|
|
||||||
// packet of the QoS 2 protocol exchange.
|
|
||||||
var _ Message = (*PubrecMessage)(nil)
|
|
||||||
|
|
||||||
// NewPubrecMessage creates a new PUBREC message.
|
|
||||||
func NewPubrecMessage() *PubrecMessage {
|
|
||||||
msg := &PubrecMessage{}
|
|
||||||
msg.SetType(PUBREC)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPubrecMessageFields(t *testing.T) {
|
|
||||||
msg := NewPubrecMessage()
|
|
||||||
|
|
||||||
msg.SetPacketId(100)
|
|
||||||
|
|
||||||
require.Equal(t, 100, int(msg.PacketId()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPubrecMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBREC << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubrecMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, PUBREC, msg.Type(), "Error decoding message.")
|
|
||||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test insufficient bytes
|
|
||||||
func TestPubrecMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBREC << 4),
|
|
||||||
2,
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubrecMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPubrecMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBREC << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubrecMessage()
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestPubrecDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBREC << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubrecMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
// A PUBREL Packet is the response to a PUBREC Packet. It is the third packet of the
|
|
||||||
// QoS 2 protocol exchange.
|
|
||||||
type PubrelMessage struct {
|
|
||||||
PubackMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*PubrelMessage)(nil)
|
|
||||||
|
|
||||||
// NewPubrelMessage creates a new PUBREL message.
|
|
||||||
func NewPubrelMessage() *PubrelMessage {
|
|
||||||
msg := &PubrelMessage{}
|
|
||||||
msg.SetType(PUBREL)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPubrelMessageFields(t *testing.T) {
|
|
||||||
msg := NewPubrelMessage()
|
|
||||||
|
|
||||||
msg.SetPacketId(100)
|
|
||||||
|
|
||||||
require.Equal(t, 100, int(msg.PacketId()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPubrelMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBREL<<4) | 2,
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubrelMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, PUBREL, msg.Type(), "Error decoding message.")
|
|
||||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test insufficient bytes
|
|
||||||
func TestPubrelMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBREL<<4) | 2,
|
|
||||||
2,
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubrelMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPubrelMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBREL<<4) | 2,
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubrelMessage()
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestPubrelDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(PUBREL<<4) | 2,
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewPubrelMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,160 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// A SUBACK Packet is sent by the Server to the Client to confirm receipt and processing
|
|
||||||
// of a SUBSCRIBE Packet.
|
|
||||||
//
|
|
||||||
// A SUBACK Packet contains a list of return codes, that specify the maximum QoS level
|
|
||||||
// that was granted in each Subscription that was requested by the SUBSCRIBE.
|
|
||||||
type SubackMessage struct {
|
|
||||||
header
|
|
||||||
|
|
||||||
returnCodes []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*SubackMessage)(nil)
|
|
||||||
|
|
||||||
// NewSubackMessage creates a new SUBACK message.
|
|
||||||
func NewSubackMessage() *SubackMessage {
|
|
||||||
msg := &SubackMessage{}
|
|
||||||
msg.SetType(SUBACK)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
// String returns a string representation of the message.
|
|
||||||
func (this SubackMessage) String() string {
|
|
||||||
return fmt.Sprintf("%s, Packet ID=%d, Return Codes=%v", this.header, this.PacketId(), this.returnCodes)
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReturnCodes returns the list of QoS returns from the subscriptions sent in the SUBSCRIBE message.
|
|
||||||
func (this *SubackMessage) ReturnCodes() []byte {
|
|
||||||
return this.returnCodes
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddReturnCodes sets the list of QoS returns from the subscriptions sent in the SUBSCRIBE message.
|
|
||||||
// An error is returned if any of the QoS values are not valid.
|
|
||||||
func (this *SubackMessage) AddReturnCodes(ret []byte) error {
|
|
||||||
for _, c := range ret {
|
|
||||||
if c != QosAtMostOnce && c != QosAtLeastOnce && c != QosExactlyOnce && c != QosFailure {
|
|
||||||
return fmt.Errorf("suback/AddReturnCode: Invalid return code %d. Must be 0, 1, 2, 0x80.", c)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.returnCodes = append(this.returnCodes, c)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddReturnCode adds a single QoS return value.
|
|
||||||
func (this *SubackMessage) AddReturnCode(ret byte) error {
|
|
||||||
return this.AddReturnCodes([]byte{ret})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *SubackMessage) Len() int {
|
|
||||||
if !this.dirty {
|
|
||||||
return len(this.dbuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.header.msglen() + ml
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *SubackMessage) Decode(src []byte) (int, error) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
hn, err := this.header.decode(src[total:])
|
|
||||||
total += hn
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
|
||||||
this.packetId = src[total : total+2]
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
l := int(this.remlen) - (total - hn)
|
|
||||||
this.returnCodes = src[total : total+l]
|
|
||||||
total += len(this.returnCodes)
|
|
||||||
|
|
||||||
for i, code := range this.returnCodes {
|
|
||||||
if code != 0x00 && code != 0x01 && code != 0x02 && code != 0x80 {
|
|
||||||
return total, fmt.Errorf("suback/Decode: Invalid return code %d for topic %d", code, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = false
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *SubackMessage) Encode(dst []byte) (int, error) {
|
|
||||||
if !this.dirty {
|
|
||||||
if len(dst) < len(this.dbuf) {
|
|
||||||
return 0, fmt.Errorf("suback/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(dst, this.dbuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, code := range this.returnCodes {
|
|
||||||
if code != 0x00 && code != 0x01 && code != 0x02 && code != 0x80 {
|
|
||||||
return 0, fmt.Errorf("suback/Encode: Invalid return code %d for topic %d", code, i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
hl := this.header.msglen()
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if len(dst) < hl+ml {
|
|
||||||
return 0, fmt.Errorf("suback/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.encode(dst[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if copy(dst[total:total+2], this.packetId) != 2 {
|
|
||||||
dst[total], dst[total+1] = 0, 0
|
|
||||||
}
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
copy(dst[total:], this.returnCodes)
|
|
||||||
total += len(this.returnCodes)
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *SubackMessage) msglen() int {
|
|
||||||
return 2 + len(this.returnCodes)
|
|
||||||
}
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSubackMessageFields(t *testing.T) {
|
|
||||||
msg := NewSubackMessage()
|
|
||||||
|
|
||||||
msg.SetPacketId(100)
|
|
||||||
require.Equal(t, 100, int(msg.PacketId()), "Error setting packet ID.")
|
|
||||||
|
|
||||||
msg.AddReturnCode(1)
|
|
||||||
require.Equal(t, 1, len(msg.ReturnCodes()), "Error adding return code.")
|
|
||||||
|
|
||||||
err := msg.AddReturnCode(0x90)
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSubackMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(SUBACK << 4),
|
|
||||||
6,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // return code 1
|
|
||||||
1, // return code 2
|
|
||||||
2, // return code 3
|
|
||||||
0x80, // return code 4
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewSubackMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, SUBACK, msg.Type(), "Error decoding message.")
|
|
||||||
require.Equal(t, 4, len(msg.ReturnCodes()), "Error adding return code.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test with wrong return code
|
|
||||||
func TestSubackMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(SUBACK << 4),
|
|
||||||
6,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // return code 1
|
|
||||||
1, // return code 2
|
|
||||||
2, // return code 3
|
|
||||||
0x81, // return code 4
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewSubackMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSubackMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(SUBACK << 4),
|
|
||||||
6,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // return code 1
|
|
||||||
1, // return code 2
|
|
||||||
2, // return code 3
|
|
||||||
0x80, // return code 4
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewSubackMessage()
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
msg.AddReturnCode(0)
|
|
||||||
msg.AddReturnCode(1)
|
|
||||||
msg.AddReturnCode(2)
|
|
||||||
msg.AddReturnCode(0x80)
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestSubackDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(SUBACK << 4),
|
|
||||||
6,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // return code 1
|
|
||||||
1, // return code 2
|
|
||||||
2, // return code 3
|
|
||||||
0x80, // return code 4
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewSubackMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,253 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
// The SUBSCRIBE Packet is sent from the Client to the Server to create one or more
|
|
||||||
// Subscriptions. Each Subscription registers a Client’s interest in one or more
|
|
||||||
// Topics. The Server sends PUBLISH Packets to the Client in order to forward
|
|
||||||
// Application Messages that were published to Topics that match these Subscriptions.
|
|
||||||
// The SUBSCRIBE Packet also specifies (for each Subscription) the maximum QoS with
|
|
||||||
// which the Server can send Application Messages to the Client.
|
|
||||||
type SubscribeMessage struct {
|
|
||||||
header
|
|
||||||
|
|
||||||
topics [][]byte
|
|
||||||
qos []byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*SubscribeMessage)(nil)
|
|
||||||
|
|
||||||
// NewSubscribeMessage creates a new SUBSCRIBE message.
|
|
||||||
func NewSubscribeMessage() *SubscribeMessage {
|
|
||||||
msg := &SubscribeMessage{}
|
|
||||||
msg.SetType(SUBSCRIBE)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this SubscribeMessage) String() string {
|
|
||||||
msgstr := fmt.Sprintf("%s, Packet ID=%d", this.header, this.PacketId())
|
|
||||||
|
|
||||||
for i, t := range this.topics {
|
|
||||||
msgstr = fmt.Sprintf("%s, Topic[%d]=%q/%d", msgstr, i, string(t), this.qos[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
return msgstr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Topics returns a list of topics sent by the Client.
|
|
||||||
func (this *SubscribeMessage) Topics() [][]byte {
|
|
||||||
return this.topics
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddTopic adds a single topic to the message, along with the corresponding QoS.
|
|
||||||
// An error is returned if QoS is invalid.
|
|
||||||
func (this *SubscribeMessage) AddTopic(topic []byte, qos byte) error {
|
|
||||||
if !ValidQos(qos) {
|
|
||||||
return fmt.Errorf("Invalid QoS %d", qos)
|
|
||||||
}
|
|
||||||
|
|
||||||
var i int
|
|
||||||
var t []byte
|
|
||||||
var found bool
|
|
||||||
|
|
||||||
for i, t = range this.topics {
|
|
||||||
if bytes.Equal(t, topic) {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
this.qos[i] = qos
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
this.topics = append(this.topics, topic)
|
|
||||||
this.qos = append(this.qos, qos)
|
|
||||||
this.dirty = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveTopic removes a single topic from the list of existing ones in the message.
|
|
||||||
// If topic does not exist it just does nothing.
|
|
||||||
func (this *SubscribeMessage) RemoveTopic(topic []byte) {
|
|
||||||
var i int
|
|
||||||
var t []byte
|
|
||||||
var found bool
|
|
||||||
|
|
||||||
for i, t = range this.topics {
|
|
||||||
if bytes.Equal(t, topic) {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
this.topics = append(this.topics[:i], this.topics[i+1:]...)
|
|
||||||
this.qos = append(this.qos[:i], this.qos[i+1:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// TopicExists checks to see if a topic exists in the list.
|
|
||||||
func (this *SubscribeMessage) TopicExists(topic []byte) bool {
|
|
||||||
for _, t := range this.topics {
|
|
||||||
if bytes.Equal(t, topic) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// TopicQos returns the QoS level of a topic. If topic does not exist, QosFailure
|
|
||||||
// is returned.
|
|
||||||
func (this *SubscribeMessage) TopicQos(topic []byte) byte {
|
|
||||||
for i, t := range this.topics {
|
|
||||||
if bytes.Equal(t, topic) {
|
|
||||||
return this.qos[i]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return QosFailure
|
|
||||||
}
|
|
||||||
|
|
||||||
// Qos returns the list of QoS current in the message.
|
|
||||||
func (this *SubscribeMessage) Qos() []byte {
|
|
||||||
return this.qos
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *SubscribeMessage) Len() int {
|
|
||||||
if !this.dirty {
|
|
||||||
return len(this.dbuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.header.msglen() + ml
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *SubscribeMessage) Decode(src []byte) (int, error) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
hn, err := this.header.decode(src[total:])
|
|
||||||
total += hn
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
|
||||||
this.packetId = src[total : total+2]
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
remlen := int(this.remlen) - (total - hn)
|
|
||||||
for remlen > 0 {
|
|
||||||
t, n, err := readLPBytes(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
this.topics = append(this.topics, t)
|
|
||||||
|
|
||||||
this.qos = append(this.qos, src[total])
|
|
||||||
total++
|
|
||||||
|
|
||||||
remlen = remlen - n - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(this.topics) == 0 {
|
|
||||||
return 0, fmt.Errorf("subscribe/Decode: Empty topic list")
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = false
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *SubscribeMessage) Encode(dst []byte) (int, error) {
|
|
||||||
if !this.dirty {
|
|
||||||
if len(dst) < len(this.dbuf) {
|
|
||||||
return 0, fmt.Errorf("subscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(dst, this.dbuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
hl := this.header.msglen()
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if len(dst) < hl+ml {
|
|
||||||
return 0, fmt.Errorf("subscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.encode(dst[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if this.PacketId() == 0 {
|
|
||||||
this.SetPacketId(uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff))
|
|
||||||
//this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff)
|
|
||||||
}
|
|
||||||
|
|
||||||
n = copy(dst[total:], this.packetId)
|
|
||||||
//binary.BigEndian.PutUint16(dst[total:], this.packetId)
|
|
||||||
total += n
|
|
||||||
|
|
||||||
for i, t := range this.topics {
|
|
||||||
n, err := writeLPBytes(dst[total:], t)
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
dst[total] = this.qos[i]
|
|
||||||
total++
|
|
||||||
}
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *SubscribeMessage) msglen() int {
|
|
||||||
// packet ID
|
|
||||||
total := 2
|
|
||||||
|
|
||||||
for _, t := range this.topics {
|
|
||||||
total += 2 + len(t) + 1
|
|
||||||
}
|
|
||||||
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
@@ -1,161 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSubscribeMessageFields(t *testing.T) {
|
|
||||||
msg := NewSubscribeMessage()
|
|
||||||
|
|
||||||
msg.SetPacketId(100)
|
|
||||||
require.Equal(t, 100, int(msg.PacketId()), "Error setting packet ID.")
|
|
||||||
|
|
||||||
msg.AddTopic([]byte("/a/b/#/c"), 1)
|
|
||||||
require.Equal(t, 1, len(msg.Topics()), "Error adding topic.")
|
|
||||||
|
|
||||||
require.False(t, msg.TopicExists([]byte("a/b")), "Topic should not exist.")
|
|
||||||
|
|
||||||
msg.RemoveTopic([]byte("/a/b/#/c"))
|
|
||||||
require.False(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSubscribeMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(SUBSCRIBE<<4) | 2,
|
|
||||||
36,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // QoS
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
8, // topic name LSB (8)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
|
||||||
1, // QoS
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
10, // topic name LSB (10)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
|
||||||
2, // QoS
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewSubscribeMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, SUBSCRIBE, msg.Type(), "Error decoding message.")
|
|
||||||
require.Equal(t, 3, len(msg.Topics()), "Error decoding topics.")
|
|
||||||
require.True(t, msg.TopicExists([]byte("surgemq")), "Topic 'surgemq' should exist.")
|
|
||||||
require.Equal(t, 0, int(msg.TopicQos([]byte("surgemq"))), "Incorrect topic qos.")
|
|
||||||
require.True(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic '/a/b/#/c' should exist.")
|
|
||||||
require.Equal(t, 1, int(msg.TopicQos([]byte("/a/b/#/c"))), "Incorrect topic qos.")
|
|
||||||
require.True(t, msg.TopicExists([]byte("/a/b/#/cdd")), "Topic '/a/b/#/c' should exist.")
|
|
||||||
require.Equal(t, 2, int(msg.TopicQos([]byte("/a/b/#/cdd"))), "Incorrect topic qos.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test empty topic list
|
|
||||||
func TestSubscribeMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(SUBSCRIBE<<4) | 2,
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewSubscribeMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSubscribeMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(SUBSCRIBE<<4) | 2,
|
|
||||||
36,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // QoS
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
8, // topic name LSB (8)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
|
||||||
1, // QoS
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
10, // topic name LSB (10)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
|
||||||
2, // QoS
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewSubscribeMessage()
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
msg.AddTopic([]byte("surgemq"), 0)
|
|
||||||
msg.AddTopic([]byte("/a/b/#/c"), 1)
|
|
||||||
msg.AddTopic([]byte("/a/b/#/cdd"), 2)
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestSubscribeDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(SUBSCRIBE<<4) | 2,
|
|
||||||
36,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // QoS
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
8, // topic name LSB (8)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
|
||||||
1, // QoS
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
10, // topic name LSB (10)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
|
||||||
2, // QoS
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewSubscribeMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
// The UNSUBACK Packet is sent by the Server to the Client to confirm receipt of an
|
|
||||||
// UNSUBSCRIBE Packet.
|
|
||||||
type UnsubackMessage struct {
|
|
||||||
PubackMessage
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*UnsubackMessage)(nil)
|
|
||||||
|
|
||||||
// NewUnsubackMessage creates a new UNSUBACK message.
|
|
||||||
func NewUnsubackMessage() *UnsubackMessage {
|
|
||||||
msg := &UnsubackMessage{}
|
|
||||||
msg.SetType(UNSUBACK)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUnsubackMessageFields(t *testing.T) {
|
|
||||||
msg := NewUnsubackMessage()
|
|
||||||
|
|
||||||
msg.SetPacketId(100)
|
|
||||||
|
|
||||||
require.Equal(t, 100, int(msg.PacketId()))
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnsubackMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(UNSUBACK << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewUnsubackMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, UNSUBACK, msg.Type(), "Error decoding message.")
|
|
||||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test insufficient bytes
|
|
||||||
func TestUnsubackMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(UNSUBACK << 4),
|
|
||||||
2,
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewUnsubackMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnsubackMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(UNSUBACK << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewUnsubackMessage()
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
|
|
||||||
dst := make([]byte, 10)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestUnsubackDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(UNSUBACK << 4),
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewUnsubackMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -1,210 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"sync/atomic"
|
|
||||||
)
|
|
||||||
|
|
||||||
// An UNSUBSCRIBE Packet is sent by the Client to the Server, to unsubscribe from topics.
|
|
||||||
type UnsubscribeMessage struct {
|
|
||||||
header
|
|
||||||
|
|
||||||
topics [][]byte
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ Message = (*UnsubscribeMessage)(nil)
|
|
||||||
|
|
||||||
// NewUnsubscribeMessage creates a new UNSUBSCRIBE message.
|
|
||||||
func NewUnsubscribeMessage() *UnsubscribeMessage {
|
|
||||||
msg := &UnsubscribeMessage{}
|
|
||||||
msg.SetType(UNSUBSCRIBE)
|
|
||||||
|
|
||||||
return msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this UnsubscribeMessage) String() string {
|
|
||||||
msgstr := fmt.Sprintf("%s", this.header)
|
|
||||||
|
|
||||||
for i, t := range this.topics {
|
|
||||||
msgstr = fmt.Sprintf("%s, Topic%d=%s", msgstr, i, string(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
return msgstr
|
|
||||||
}
|
|
||||||
|
|
||||||
// Topics returns a list of topics sent by the Client.
|
|
||||||
func (this *UnsubscribeMessage) Topics() [][]byte {
|
|
||||||
return this.topics
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddTopic adds a single topic to the message.
|
|
||||||
func (this *UnsubscribeMessage) AddTopic(topic []byte) {
|
|
||||||
if this.TopicExists(topic) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
this.topics = append(this.topics, topic)
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveTopic removes a single topic from the list of existing ones in the message.
|
|
||||||
// If topic does not exist it just does nothing.
|
|
||||||
func (this *UnsubscribeMessage) RemoveTopic(topic []byte) {
|
|
||||||
var i int
|
|
||||||
var t []byte
|
|
||||||
var found bool
|
|
||||||
|
|
||||||
for i, t = range this.topics {
|
|
||||||
if bytes.Equal(t, topic) {
|
|
||||||
found = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if found {
|
|
||||||
this.topics = append(this.topics[:i], this.topics[i+1:]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// TopicExists checks to see if a topic exists in the list.
|
|
||||||
func (this *UnsubscribeMessage) TopicExists(topic []byte) bool {
|
|
||||||
for _, t := range this.topics {
|
|
||||||
if bytes.Equal(t, topic) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *UnsubscribeMessage) Len() int {
|
|
||||||
if !this.dirty {
|
|
||||||
return len(this.dbuf)
|
|
||||||
}
|
|
||||||
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
return this.header.msglen() + ml
|
|
||||||
}
|
|
||||||
|
|
||||||
// Decode reads from the io.Reader parameter until a full message is decoded, or
|
|
||||||
// when io.Reader returns EOF or error. The first return value is the number of
|
|
||||||
// bytes read from io.Reader. The second is error if Decode encounters any problems.
|
|
||||||
func (this *UnsubscribeMessage) Decode(src []byte) (int, error) {
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
hn, err := this.header.decode(src[total:])
|
|
||||||
total += hn
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
|
||||||
this.packetId = src[total : total+2]
|
|
||||||
total += 2
|
|
||||||
|
|
||||||
remlen := int(this.remlen) - (total - hn)
|
|
||||||
for remlen > 0 {
|
|
||||||
t, n, err := readLPBytes(src[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
this.topics = append(this.topics, t)
|
|
||||||
remlen = remlen - n - 1
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(this.topics) == 0 {
|
|
||||||
return 0, fmt.Errorf("unsubscribe/Decode: Empty topic list")
|
|
||||||
}
|
|
||||||
|
|
||||||
this.dirty = false
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Encode returns an io.Reader in which the encoded bytes can be read. The second
|
|
||||||
// return value is the number of bytes encoded, so the caller knows how many bytes
|
|
||||||
// there will be. If Encode returns an error, then the first two return values
|
|
||||||
// should be considered invalid.
|
|
||||||
// Any changes to the message after Encode() is called will invalidate the io.Reader.
|
|
||||||
func (this *UnsubscribeMessage) Encode(dst []byte) (int, error) {
|
|
||||||
if !this.dirty {
|
|
||||||
if len(dst) < len(this.dbuf) {
|
|
||||||
return 0, fmt.Errorf("unsubscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
return copy(dst, this.dbuf), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
hl := this.header.msglen()
|
|
||||||
ml := this.msglen()
|
|
||||||
|
|
||||||
if len(dst) < hl+ml {
|
|
||||||
return 0, fmt.Errorf("unsubscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
total := 0
|
|
||||||
|
|
||||||
n, err := this.header.encode(dst[total:])
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if this.PacketId() == 0 {
|
|
||||||
this.SetPacketId(uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff))
|
|
||||||
//this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff)
|
|
||||||
}
|
|
||||||
|
|
||||||
n = copy(dst[total:], this.packetId)
|
|
||||||
//binary.BigEndian.PutUint16(dst[total:], this.packetId)
|
|
||||||
total += n
|
|
||||||
|
|
||||||
for _, t := range this.topics {
|
|
||||||
n, err := writeLPBytes(dst[total:], t)
|
|
||||||
total += n
|
|
||||||
if err != nil {
|
|
||||||
return total, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return total, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (this *UnsubscribeMessage) msglen() int {
|
|
||||||
// packet ID
|
|
||||||
total := 2
|
|
||||||
|
|
||||||
for _, t := range this.topics {
|
|
||||||
total += 2 + len(t)
|
|
||||||
}
|
|
||||||
|
|
||||||
return total
|
|
||||||
}
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
package message
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestUnsubscribeMessageFields(t *testing.T) {
|
|
||||||
msg := NewUnsubscribeMessage()
|
|
||||||
|
|
||||||
msg.SetPacketId(100)
|
|
||||||
require.Equal(t, 100, int(msg.PacketId()), "Error setting packet ID.")
|
|
||||||
|
|
||||||
msg.AddTopic([]byte("/a/b/#/c"))
|
|
||||||
require.Equal(t, 1, len(msg.Topics()), "Error adding topic.")
|
|
||||||
|
|
||||||
msg.AddTopic([]byte("/a/b/#/c"))
|
|
||||||
require.Equal(t, 1, len(msg.Topics()), "Error adding duplicate topic.")
|
|
||||||
|
|
||||||
msg.RemoveTopic([]byte("/a/b/#/c"))
|
|
||||||
require.False(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.")
|
|
||||||
|
|
||||||
require.False(t, msg.TopicExists([]byte("a/b")), "Topic should not exist.")
|
|
||||||
|
|
||||||
msg.RemoveTopic([]byte("/a/b/#/c"))
|
|
||||||
require.False(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnsubscribeMessageDecode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(UNSUBSCRIBE<<4) | 2,
|
|
||||||
33,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
8, // topic name LSB (8)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
10, // topic name LSB (10)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewUnsubscribeMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, UNSUBSCRIBE, msg.Type(), "Error decoding message.")
|
|
||||||
require.Equal(t, 3, len(msg.Topics()), "Error decoding topics.")
|
|
||||||
require.True(t, msg.TopicExists([]byte("surgemq")), "Topic 'surgemq' should exist.")
|
|
||||||
require.True(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic '/a/b/#/c' should exist.")
|
|
||||||
require.True(t, msg.TopicExists([]byte("/a/b/#/cdd")), "Topic '/a/b/#/c' should exist.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test empty topic list
|
|
||||||
func TestUnsubscribeMessageDecode2(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(UNSUBSCRIBE<<4) | 2,
|
|
||||||
2,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewUnsubscribeMessage()
|
|
||||||
_, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestUnsubscribeMessageEncode(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(UNSUBSCRIBE<<4) | 2,
|
|
||||||
33,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
8, // topic name LSB (8)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
10, // topic name LSB (10)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewUnsubscribeMessage()
|
|
||||||
msg.SetPacketId(7)
|
|
||||||
msg.AddTopic([]byte("surgemq"))
|
|
||||||
msg.AddTopic([]byte("/a/b/#/c"))
|
|
||||||
msg.AddTopic([]byte("/a/b/#/cdd"))
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
|
||||||
}
|
|
||||||
|
|
||||||
// test to ensure encoding and decoding are the same
|
|
||||||
// decode, encode, and decode again
|
|
||||||
func TestUnsubscribeDecodeEncodeEquiv(t *testing.T) {
|
|
||||||
msgBytes := []byte{
|
|
||||||
byte(UNSUBSCRIBE<<4) | 2,
|
|
||||||
33,
|
|
||||||
0, // packet ID MSB (0)
|
|
||||||
7, // packet ID LSB (7)
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
7, // topic name LSB (7)
|
|
||||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
8, // topic name LSB (8)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
|
||||||
0, // topic name MSB (0)
|
|
||||||
10, // topic name LSB (10)
|
|
||||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
|
||||||
}
|
|
||||||
|
|
||||||
msg := NewUnsubscribeMessage()
|
|
||||||
n, err := msg.Decode(msgBytes)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
|
||||||
|
|
||||||
dst := make([]byte, 100)
|
|
||||||
n2, err := msg.Encode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
|
||||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
|
||||||
|
|
||||||
n3, err := msg.Decode(dst)
|
|
||||||
|
|
||||||
require.NoError(t, err, "Error decoding message.")
|
|
||||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,62 @@
|
|||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ SessionsProvider = (*memProvider)(nil)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Register("mem", NewMemProvider())
|
||||||
|
}
|
||||||
|
|
||||||
|
type memProvider struct {
|
||||||
|
st map[string]*Session
|
||||||
|
mu sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewMemProvider() *memProvider {
|
||||||
|
return &memProvider{
|
||||||
|
st: make(map[string]*Session),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memProvider) New(id string) (*Session, error) {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
|
||||||
|
this.st[id] = &Session{id: id}
|
||||||
|
return this.st[id], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memProvider) Get(id string) (*Session, error) {
|
||||||
|
this.mu.RLock()
|
||||||
|
defer this.mu.RUnlock()
|
||||||
|
|
||||||
|
sess, ok := this.st[id]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("store/Get: No session found for key %s", id)
|
||||||
|
}
|
||||||
|
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memProvider) Del(id string) {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
delete(this.st, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memProvider) Save(id string) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memProvider) Count() int {
|
||||||
|
return len(this.st)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memProvider) Close() error {
|
||||||
|
this.st = make(map[string]*Session)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,149 @@
|
|||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// Queue size for the ack queue
|
||||||
|
defaultQueueSize = 16
|
||||||
|
)
|
||||||
|
|
||||||
|
type Session struct {
|
||||||
|
|
||||||
|
// cmsg is the CONNECT message
|
||||||
|
cmsg *packets.ConnectPacket
|
||||||
|
|
||||||
|
// Will message to publish if connect is closed unexpectedly
|
||||||
|
Will *packets.PublishPacket
|
||||||
|
|
||||||
|
// Retained publish message
|
||||||
|
Retained *packets.PublishPacket
|
||||||
|
|
||||||
|
// topics stores all the topis for this session/client
|
||||||
|
topics map[string]byte
|
||||||
|
|
||||||
|
// Initialized?
|
||||||
|
initted bool
|
||||||
|
|
||||||
|
// Serialize access to this session
|
||||||
|
mu sync.Mutex
|
||||||
|
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) Init(msg *packets.ConnectPacket) error {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
|
||||||
|
if this.initted {
|
||||||
|
return fmt.Errorf("Session already initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
this.cmsg = msg
|
||||||
|
|
||||||
|
if this.cmsg.WillFlag {
|
||||||
|
this.Will = packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
|
||||||
|
this.Will.Qos = this.cmsg.Qos
|
||||||
|
this.Will.TopicName = this.cmsg.WillTopic
|
||||||
|
this.Will.Payload = this.cmsg.WillMessage
|
||||||
|
this.Will.Retain = this.cmsg.WillRetain
|
||||||
|
}
|
||||||
|
|
||||||
|
this.topics = make(map[string]byte, 1)
|
||||||
|
|
||||||
|
this.id = string(msg.ClientIdentifier)
|
||||||
|
|
||||||
|
this.initted = true
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) Update(msg *packets.ConnectPacket) error {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
|
||||||
|
this.cmsg = msg
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) RetainMessage(msg *packets.PublishPacket) error {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
|
||||||
|
this.Retained = msg
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) AddTopic(topic string, qos byte) error {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
|
||||||
|
if !this.initted {
|
||||||
|
return fmt.Errorf("Session not yet initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
this.topics[topic] = qos
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) RemoveTopic(topic string) error {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
|
||||||
|
if !this.initted {
|
||||||
|
return fmt.Errorf("Session not yet initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(this.topics, topic)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) Topics() ([]string, []byte, error) {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
|
||||||
|
if !this.initted {
|
||||||
|
return nil, nil, fmt.Errorf("Session not yet initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
topics []string
|
||||||
|
qoss []byte
|
||||||
|
)
|
||||||
|
|
||||||
|
for k, v := range this.topics {
|
||||||
|
topics = append(topics, k)
|
||||||
|
qoss = append(qoss, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
return topics, qoss, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) ID() string {
|
||||||
|
return this.cmsg.ClientIdentifier
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) WillFlag() bool {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
return this.cmsg.WillFlag
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) SetWillFlag(v bool) {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
this.cmsg.WillFlag = v
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Session) CleanSession() bool {
|
||||||
|
this.mu.Lock()
|
||||||
|
defer this.mu.Unlock()
|
||||||
|
return this.cmsg.CleanSession
|
||||||
|
}
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
package sessions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"encoding/base64"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrSessionsProviderNotFound = errors.New("Session: Session provider not found")
|
||||||
|
ErrKeyNotAvailable = errors.New("Session: not item found for key.")
|
||||||
|
|
||||||
|
providers = make(map[string]SessionsProvider)
|
||||||
|
)
|
||||||
|
|
||||||
|
type SessionsProvider interface {
|
||||||
|
New(id string) (*Session, error)
|
||||||
|
Get(id string) (*Session, error)
|
||||||
|
Del(id string)
|
||||||
|
Save(id string) error
|
||||||
|
Count() int
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register makes a session provider available by the provided name.
|
||||||
|
// If a Register is called twice with the same name or if the driver is nil,
|
||||||
|
// it panics.
|
||||||
|
func Register(name string, provider SessionsProvider) {
|
||||||
|
if provider == nil {
|
||||||
|
panic("session: Register provide is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, dup := providers[name]; dup {
|
||||||
|
panic("session: Register called twice for provider " + name)
|
||||||
|
}
|
||||||
|
|
||||||
|
providers[name] = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
func Unregister(name string) {
|
||||||
|
delete(providers, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
p SessionsProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(providerName string) (*Manager, error) {
|
||||||
|
p, ok := providers[providerName]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("session: unknown provider %q", providerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Manager{p: p}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) New(id string) (*Session, error) {
|
||||||
|
if id == "" {
|
||||||
|
id = this.sessionId()
|
||||||
|
}
|
||||||
|
return this.p.New(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Get(id string) (*Session, error) {
|
||||||
|
return this.p.Get(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Del(id string) {
|
||||||
|
this.p.Del(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Save(id string) error {
|
||||||
|
return this.p.Save(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Count() int {
|
||||||
|
return this.p.Count()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Close() error {
|
||||||
|
return this.p.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (manager *Manager) sessionId() string {
|
||||||
|
b := make([]byte, 15)
|
||||||
|
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return base64.URLEncoding.EncodeToString(b)
|
||||||
|
}
|
||||||
@@ -0,0 +1,549 @@
|
|||||||
|
package topics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
QosAtMostOnce byte = iota
|
||||||
|
QosAtLeastOnce
|
||||||
|
QosExactlyOnce
|
||||||
|
QosFailure = 0x80
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ TopicsProvider = (*memTopics)(nil)
|
||||||
|
|
||||||
|
type memTopics struct {
|
||||||
|
// Sub/unsub mutex
|
||||||
|
smu sync.RWMutex
|
||||||
|
// Subscription tree
|
||||||
|
sroot *snode
|
||||||
|
|
||||||
|
// Retained message mutex
|
||||||
|
rmu sync.RWMutex
|
||||||
|
// Retained messages topic tree
|
||||||
|
rroot *rnode
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
Register("mem", NewMemProvider())
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMemProvider returns an new instance of the memTopics, which is implements the
|
||||||
|
// TopicsProvider interface. memProvider is a hidden struct that stores the topic
|
||||||
|
// subscriptions and retained messages in memory. The content is not persistend so
|
||||||
|
// when the server goes, everything will be gone. Use with care.
|
||||||
|
func NewMemProvider() *memTopics {
|
||||||
|
return &memTopics{
|
||||||
|
sroot: newSNode(),
|
||||||
|
rroot: newRNode(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ValidQos(qos byte) bool {
|
||||||
|
return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memTopics) Subscribe(topic []byte, qos byte, sub interface{}) (byte, error) {
|
||||||
|
if !ValidQos(qos) {
|
||||||
|
return QosFailure, fmt.Errorf("Invalid QoS %d", qos)
|
||||||
|
}
|
||||||
|
|
||||||
|
if sub == nil {
|
||||||
|
return QosFailure, fmt.Errorf("Subscriber cannot be nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
this.smu.Lock()
|
||||||
|
defer this.smu.Unlock()
|
||||||
|
|
||||||
|
if qos > QosExactlyOnce {
|
||||||
|
qos = QosExactlyOnce
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := this.sroot.sinsert(topic, qos, sub); err != nil {
|
||||||
|
return QosFailure, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return qos, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memTopics) Unsubscribe(topic []byte, sub interface{}) error {
|
||||||
|
this.smu.Lock()
|
||||||
|
defer this.smu.Unlock()
|
||||||
|
|
||||||
|
return this.sroot.sremove(topic, sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returned values will be invalidated by the next Subscribers call
|
||||||
|
func (this *memTopics) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
|
||||||
|
if !ValidQos(qos) {
|
||||||
|
return fmt.Errorf("Invalid QoS %d", qos)
|
||||||
|
}
|
||||||
|
|
||||||
|
this.smu.RLock()
|
||||||
|
defer this.smu.RUnlock()
|
||||||
|
|
||||||
|
*subs = (*subs)[0:0]
|
||||||
|
*qoss = (*qoss)[0:0]
|
||||||
|
|
||||||
|
return this.sroot.smatch(topic, qos, subs, qoss)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memTopics) Retain(msg *packets.PublishPacket) error {
|
||||||
|
this.rmu.Lock()
|
||||||
|
defer this.rmu.Unlock()
|
||||||
|
|
||||||
|
// So apparently, at least according to the MQTT Conformance/Interoperability
|
||||||
|
// Testing, that a payload of 0 means delete the retain message.
|
||||||
|
// https://eclipse.org/paho/clients/testing/
|
||||||
|
if len(msg.Payload) == 0 {
|
||||||
|
return this.rroot.rremove([]byte(msg.TopicName))
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.rroot.rinsert([]byte(msg.TopicName), msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memTopics) Retained(topic []byte, msgs *[]*packets.PublishPacket) error {
|
||||||
|
this.rmu.RLock()
|
||||||
|
defer this.rmu.RUnlock()
|
||||||
|
|
||||||
|
return this.rroot.rmatch(topic, msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *memTopics) Close() error {
|
||||||
|
this.sroot = nil
|
||||||
|
this.rroot = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// subscrition nodes
|
||||||
|
type snode struct {
|
||||||
|
// If this is the end of the topic string, then add subscribers here
|
||||||
|
subs []interface{}
|
||||||
|
qos []byte
|
||||||
|
|
||||||
|
// Otherwise add the next topic level here
|
||||||
|
snodes map[string]*snode
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSNode() *snode {
|
||||||
|
return &snode{
|
||||||
|
snodes: make(map[string]*snode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *snode) sinsert(topic []byte, qos byte, sub interface{}) error {
|
||||||
|
// If there's no more topic levels, that means we are at the matching snode
|
||||||
|
// to insert the subscriber. So let's see if there's such subscriber,
|
||||||
|
// if so, update it. Otherwise insert it.
|
||||||
|
if len(topic) == 0 {
|
||||||
|
// Let's see if the subscriber is already on the list. If yes, update
|
||||||
|
// QoS and then return.
|
||||||
|
for i := range this.subs {
|
||||||
|
if equal(this.subs[i], sub) {
|
||||||
|
this.qos[i] = qos
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise add.
|
||||||
|
this.subs = append(this.subs, sub)
|
||||||
|
this.qos = append(this.qos, qos)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not the last level, so let's find or create the next level snode, and
|
||||||
|
// recursively call it's insert().
|
||||||
|
|
||||||
|
// ntl = next topic level
|
||||||
|
ntl, rem, err := nextTopicLevel(topic)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
level := string(ntl)
|
||||||
|
|
||||||
|
// Add snode if it doesn't already exist
|
||||||
|
n, ok := this.snodes[level]
|
||||||
|
if !ok {
|
||||||
|
n = newSNode()
|
||||||
|
this.snodes[level] = n
|
||||||
|
}
|
||||||
|
|
||||||
|
return n.sinsert(rem, qos, sub)
|
||||||
|
}
|
||||||
|
|
||||||
|
// This remove implementation ignores the QoS, as long as the subscriber
|
||||||
|
// matches then it's removed
|
||||||
|
func (this *snode) sremove(topic []byte, sub interface{}) error {
|
||||||
|
// If the topic is empty, it means we are at the final matching snode. If so,
|
||||||
|
// let's find the matching subscribers and remove them.
|
||||||
|
if len(topic) == 0 {
|
||||||
|
// If subscriber == nil, then it's signal to remove ALL subscribers
|
||||||
|
if sub == nil {
|
||||||
|
this.subs = this.subs[0:0]
|
||||||
|
this.qos = this.qos[0:0]
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we find the subscriber then remove it from the list. Technically
|
||||||
|
// we just overwrite the slot by shifting all other items up by one.
|
||||||
|
for i := range this.subs {
|
||||||
|
if equal(this.subs[i], sub) {
|
||||||
|
this.subs = append(this.subs[:i], this.subs[i+1:]...)
|
||||||
|
this.qos = append(this.qos[:i], this.qos[i+1:]...)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("No topic found for subscriber")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not the last level, so let's find the next level snode, and recursively
|
||||||
|
// call it's remove().
|
||||||
|
|
||||||
|
// ntl = next topic level
|
||||||
|
ntl, rem, err := nextTopicLevel(topic)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
level := string(ntl)
|
||||||
|
|
||||||
|
// Find the snode that matches the topic level
|
||||||
|
n, ok := this.snodes[level]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("No topic found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the subscriber from the next level snode
|
||||||
|
if err := n.sremove(rem, sub); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there are no more subscribers and snodes to the next level we just visited
|
||||||
|
// let's remove it
|
||||||
|
if len(n.subs) == 0 && len(n.snodes) == 0 {
|
||||||
|
delete(this.snodes, level)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// smatch() returns all the subscribers that are subscribed to the topic. Given a topic
|
||||||
|
// with no wildcards (publish topic), it returns a list of subscribers that subscribes
|
||||||
|
// to the topic. For each of the level names, it's a match
|
||||||
|
// - if there are subscribers to '#', then all the subscribers are added to result set
|
||||||
|
func (this *snode) smatch(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
|
||||||
|
// If the topic is empty, it means we are at the final matching snode. If so,
|
||||||
|
// let's find the subscribers that match the qos and append them to the list.
|
||||||
|
if len(topic) == 0 {
|
||||||
|
this.matchQos(qos, subs, qoss)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ntl = next topic level
|
||||||
|
ntl, rem, err := nextTopicLevel(topic)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
level := string(ntl)
|
||||||
|
|
||||||
|
for k, n := range this.snodes {
|
||||||
|
// If the key is "#", then these subscribers are added to the result set
|
||||||
|
if k == MWC {
|
||||||
|
n.matchQos(qos, subs, qoss)
|
||||||
|
} else if k == SWC || k == level {
|
||||||
|
if err := n.smatch(rem, qos, subs, qoss); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// retained message nodes
|
||||||
|
type rnode struct {
|
||||||
|
// If this is the end of the topic string, then add retained messages here
|
||||||
|
msg *packets.PublishPacket
|
||||||
|
// Otherwise add the next topic level here
|
||||||
|
rnodes map[string]*rnode
|
||||||
|
}
|
||||||
|
|
||||||
|
func newRNode() *rnode {
|
||||||
|
return &rnode{
|
||||||
|
rnodes: make(map[string]*rnode),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *rnode) rinsert(topic []byte, msg *packets.PublishPacket) error {
|
||||||
|
// If there's no more topic levels, that means we are at the matching rnode.
|
||||||
|
if len(topic) == 0 {
|
||||||
|
// Reuse the message if possible
|
||||||
|
if this.msg == nil {
|
||||||
|
this.msg = msg
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not the last level, so let's find or create the next level snode, and
|
||||||
|
// recursively call it's insert().
|
||||||
|
|
||||||
|
// ntl = next topic level
|
||||||
|
ntl, rem, err := nextTopicLevel(topic)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
level := string(ntl)
|
||||||
|
|
||||||
|
// Add snode if it doesn't already exist
|
||||||
|
n, ok := this.rnodes[level]
|
||||||
|
if !ok {
|
||||||
|
n = newRNode()
|
||||||
|
this.rnodes[level] = n
|
||||||
|
}
|
||||||
|
|
||||||
|
return n.rinsert(rem, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the retained message for the supplied topic
|
||||||
|
func (this *rnode) rremove(topic []byte) error {
|
||||||
|
// If the topic is empty, it means we are at the final matching rnode. If so,
|
||||||
|
// let's remove the buffer and message.
|
||||||
|
if len(topic) == 0 {
|
||||||
|
this.msg = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not the last level, so let's find the next level rnode, and recursively
|
||||||
|
// call it's remove().
|
||||||
|
|
||||||
|
// ntl = next topic level
|
||||||
|
ntl, rem, err := nextTopicLevel(topic)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
level := string(ntl)
|
||||||
|
|
||||||
|
// Find the rnode that matches the topic level
|
||||||
|
n, ok := this.rnodes[level]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("No topic found")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove the subscriber from the next level rnode
|
||||||
|
if err := n.rremove(rem); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there are no more rnodes to the next level we just visited let's remove it
|
||||||
|
if len(n.rnodes) == 0 {
|
||||||
|
delete(this.rnodes, level)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rmatch() finds the retained messages for the topic and qos provided. It's somewhat
|
||||||
|
// of a reverse match compare to match() since the supplied topic can contain
|
||||||
|
// wildcards, whereas the retained message topic is a full (no wildcard) topic.
|
||||||
|
func (this *rnode) rmatch(topic []byte, msgs *[]*packets.PublishPacket) error {
|
||||||
|
// If the topic is empty, it means we are at the final matching rnode. If so,
|
||||||
|
// add the retained msg to the list.
|
||||||
|
if len(topic) == 0 {
|
||||||
|
if this.msg != nil {
|
||||||
|
*msgs = append(*msgs, this.msg)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ntl = next topic level
|
||||||
|
ntl, rem, err := nextTopicLevel(topic)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
level := string(ntl)
|
||||||
|
|
||||||
|
if level == MWC {
|
||||||
|
// If '#', add all retained messages starting this node
|
||||||
|
this.allRetained(msgs)
|
||||||
|
} else if level == SWC {
|
||||||
|
// If '+', check all nodes at this level. Next levels must be matched.
|
||||||
|
for _, n := range this.rnodes {
|
||||||
|
if err := n.rmatch(rem, msgs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Otherwise, find the matching node, go to the next level
|
||||||
|
if n, ok := this.rnodes[level]; ok {
|
||||||
|
if err := n.rmatch(rem, msgs); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *rnode) allRetained(msgs *[]*packets.PublishPacket) {
|
||||||
|
if this.msg != nil {
|
||||||
|
*msgs = append(*msgs, this.msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, n := range this.rnodes {
|
||||||
|
n.allRetained(msgs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
stateCHR byte = iota // Regular character
|
||||||
|
stateMWC // Multi-level wildcard
|
||||||
|
stateSWC // Single-level wildcard
|
||||||
|
stateSEP // Topic level separator
|
||||||
|
stateSYS // System level topic ($)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Returns topic level, remaining topic levels and any errors
|
||||||
|
func nextTopicLevel(topic []byte) ([]byte, []byte, error) {
|
||||||
|
s := stateCHR
|
||||||
|
|
||||||
|
for i, c := range topic {
|
||||||
|
switch c {
|
||||||
|
case '/':
|
||||||
|
if s == stateMWC {
|
||||||
|
return nil, nil, fmt.Errorf("Multi-level wildcard found in topic and it's not at the last level")
|
||||||
|
}
|
||||||
|
|
||||||
|
if i == 0 {
|
||||||
|
return []byte(SWC), topic[i+1:], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return topic[:i], topic[i+1:], nil
|
||||||
|
|
||||||
|
case '#':
|
||||||
|
if i != 0 {
|
||||||
|
return nil, nil, fmt.Errorf("Wildcard character '#' must occupy entire topic level")
|
||||||
|
}
|
||||||
|
|
||||||
|
s = stateMWC
|
||||||
|
|
||||||
|
case '+':
|
||||||
|
if i != 0 {
|
||||||
|
return nil, nil, fmt.Errorf("Wildcard character '+' must occupy entire topic level")
|
||||||
|
}
|
||||||
|
|
||||||
|
s = stateSWC
|
||||||
|
|
||||||
|
// case '$':
|
||||||
|
// if i == 0 {
|
||||||
|
// return nil, nil, fmt.Errorf("Cannot publish to $ topics")
|
||||||
|
// }
|
||||||
|
|
||||||
|
// s = stateSYS
|
||||||
|
|
||||||
|
default:
|
||||||
|
if s == stateMWC || s == stateSWC {
|
||||||
|
return nil, nil, fmt.Errorf("Wildcard characters '#' and '+' must occupy entire topic level")
|
||||||
|
}
|
||||||
|
|
||||||
|
s = stateCHR
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If we got here that means we didn't hit the separator along the way, so the
|
||||||
|
// topic is either empty, or does not contain a separator. Either way, we return
|
||||||
|
// the full topic
|
||||||
|
return topic, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// The QoS of the payload messages sent in response to a subscription must be the
|
||||||
|
// minimum of the QoS of the originally published message (in this case, it's the
|
||||||
|
// qos parameter) and the maximum QoS granted by the server (in this case, it's
|
||||||
|
// the QoS in the topic tree).
|
||||||
|
//
|
||||||
|
// It's also possible that even if the topic matches, the subscriber is not included
|
||||||
|
// due to the QoS granted is lower than the published message QoS. For example,
|
||||||
|
// if the client is granted only QoS 0, and the publish message is QoS 1, then this
|
||||||
|
// client is not to be send the published message.
|
||||||
|
func (this *snode) matchQos(qos byte, subs *[]interface{}, qoss *[]byte) {
|
||||||
|
for _, sub := range this.subs {
|
||||||
|
// If the published QoS is higher than the subscriber QoS, then we skip the
|
||||||
|
// subscriber. Otherwise, add to the list.
|
||||||
|
// if qos >= this.qos[i] {
|
||||||
|
*subs = append(*subs, sub)
|
||||||
|
*qoss = append(*qoss, qos)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func equal(k1, k2 interface{}) bool {
|
||||||
|
if reflect.TypeOf(k1) != reflect.TypeOf(k2) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if reflect.ValueOf(k1).Kind() == reflect.Func {
|
||||||
|
return &k1 == &k2
|
||||||
|
}
|
||||||
|
|
||||||
|
if k1 == k2 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
switch k1 := k1.(type) {
|
||||||
|
case string:
|
||||||
|
return k1 == k2.(string)
|
||||||
|
|
||||||
|
case int64:
|
||||||
|
return k1 == k2.(int64)
|
||||||
|
|
||||||
|
case int32:
|
||||||
|
return k1 == k2.(int32)
|
||||||
|
|
||||||
|
case int16:
|
||||||
|
return k1 == k2.(int16)
|
||||||
|
|
||||||
|
case int8:
|
||||||
|
return k1 == k2.(int8)
|
||||||
|
|
||||||
|
case int:
|
||||||
|
return k1 == k2.(int)
|
||||||
|
|
||||||
|
case float32:
|
||||||
|
return k1 == k2.(float32)
|
||||||
|
|
||||||
|
case float64:
|
||||||
|
return k1 == k2.(float64)
|
||||||
|
|
||||||
|
case uint:
|
||||||
|
return k1 == k2.(uint)
|
||||||
|
|
||||||
|
case uint8:
|
||||||
|
return k1 == k2.(uint8)
|
||||||
|
|
||||||
|
case uint16:
|
||||||
|
return k1 == k2.(uint16)
|
||||||
|
|
||||||
|
case uint32:
|
||||||
|
return k1 == k2.(uint32)
|
||||||
|
|
||||||
|
case uint64:
|
||||||
|
return k1 == k2.(uint64)
|
||||||
|
|
||||||
|
case uintptr:
|
||||||
|
return k1 == k2.(uintptr)
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
@@ -0,0 +1,91 @@
|
|||||||
|
package topics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MWC is the multi-level wildcard
|
||||||
|
MWC = "#"
|
||||||
|
|
||||||
|
// SWC is the single level wildcard
|
||||||
|
SWC = "+"
|
||||||
|
|
||||||
|
// SEP is the topic level separator
|
||||||
|
SEP = "/"
|
||||||
|
|
||||||
|
// SYS is the starting character of the system level topics
|
||||||
|
SYS = "$"
|
||||||
|
|
||||||
|
// Both wildcards
|
||||||
|
_WC = "#+"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
providers = make(map[string]TopicsProvider)
|
||||||
|
)
|
||||||
|
|
||||||
|
// TopicsProvider
|
||||||
|
type TopicsProvider interface {
|
||||||
|
Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error)
|
||||||
|
Unsubscribe(topic []byte, subscriber interface{}) error
|
||||||
|
Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error
|
||||||
|
Retain(msg *packets.PublishPacket) error
|
||||||
|
Retained(topic []byte, msgs *[]*packets.PublishPacket) error
|
||||||
|
Close() error
|
||||||
|
}
|
||||||
|
|
||||||
|
func Register(name string, provider TopicsProvider) {
|
||||||
|
if provider == nil {
|
||||||
|
panic("topics: Register provide is nil")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, dup := providers[name]; dup {
|
||||||
|
panic("topics: Register called twice for provider " + name)
|
||||||
|
}
|
||||||
|
|
||||||
|
providers[name] = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
func Unregister(name string) {
|
||||||
|
delete(providers, name)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Manager struct {
|
||||||
|
p TopicsProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewManager(providerName string) (*Manager, error) {
|
||||||
|
p, ok := providers[providerName]
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("session: unknown provider %q", providerName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Manager{p: p}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error) {
|
||||||
|
return this.p.Subscribe(topic, qos, subscriber)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Unsubscribe(topic []byte, subscriber interface{}) error {
|
||||||
|
return this.p.Unsubscribe(topic, subscriber)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
|
||||||
|
return this.p.Subscribers(topic, qos, subs, qoss)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Retain(msg *packets.PublishPacket) error {
|
||||||
|
return this.p.Retain(msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Retained(topic []byte, msgs *[]*packets.PublishPacket) error {
|
||||||
|
return this.p.Retained(topic, msgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (this *Manager) Close() error {
|
||||||
|
return this.p.Close()
|
||||||
|
}
|
||||||
@@ -0,0 +1,50 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
|
|
||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// env can be setup at build time with Go Linker. Value could be prod or whatever else for dev env
|
||||||
|
instance *zap.Logger
|
||||||
|
logCfg zap.Config
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewDevLogger return a logger for dev builds
|
||||||
|
func NewDevLogger() (*zap.Logger, error) {
|
||||||
|
logCfg := zap.NewDevelopmentConfig()
|
||||||
|
return logCfg.Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewProdLogger return a logger for production builds
|
||||||
|
func NewProdLogger() (*zap.Logger, error) {
|
||||||
|
logCfg := zap.NewProductionConfig()
|
||||||
|
logCfg.DisableStacktrace = true
|
||||||
|
logCfg.Level = zap.NewAtomicLevelAt(zap.InfoLevel)
|
||||||
|
return logCfg.Build()
|
||||||
|
}
|
||||||
|
|
||||||
|
func InitLogger(debug bool) {
|
||||||
|
var err error
|
||||||
|
var log *zap.Logger
|
||||||
|
if debug {
|
||||||
|
log, err = NewDevLogger()
|
||||||
|
} else {
|
||||||
|
log, err = NewProdLogger()
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
panic("Unable to create a logger.")
|
||||||
|
}
|
||||||
|
defer log.Sync()
|
||||||
|
|
||||||
|
log.Debug("Logger initialization succeeded")
|
||||||
|
instance = log.Named("hmq")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get return a *zap.Logger instance
|
||||||
|
func Get() *zap.Logger {
|
||||||
|
return instance
|
||||||
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
/*
|
||||||
|
Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
*/
|
||||||
|
package logger
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"go.uber.org/zap"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGet(t *testing.T) {
|
||||||
|
var l *zap.Logger
|
||||||
|
logger := Get()
|
||||||
|
|
||||||
|
assert.NotNil(t, logger)
|
||||||
|
assert.IsType(t, l, logger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewDevLogger(t *testing.T) {
|
||||||
|
logger, err := NewDevLogger()
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.True(t, logger.Core().Enabled(zap.DebugLevel))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewProdLogger(t *testing.T) {
|
||||||
|
logger, err := NewProdLogger()
|
||||||
|
|
||||||
|
assert.Nil(t, err)
|
||||||
|
assert.False(t, logger.Core().Enabled(zap.DebugLevel))
|
||||||
|
}
|
||||||
@@ -1,29 +1,35 @@
|
|||||||
|
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||||
|
|
||||||
|
Permission to use, copy, modify, and/or distribute this software for any
|
||||||
|
purpose with or without fee is hereby granted, provided that the above
|
||||||
|
copyright notice and this permission notice appear in all copies.
|
||||||
|
*/
|
||||||
package main
|
package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"hmq/broker"
|
"log"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
log "github.com/cihub/seelog"
|
"github.com/fhmq/hmq/broker"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
config, er := broker.LoadConfig()
|
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||||
if er != nil {
|
config, err := broker.ConfigureConfig(os.Args[1:])
|
||||||
log.Error("Load Config file error: ", er)
|
if err != nil {
|
||||||
return
|
log.Fatal("configure broker config error: ", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
broker, err := broker.NewBroker(config)
|
b, err := broker.NewBroker(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error("New Broker error: ", er)
|
log.Fatal("New Broker error: ", err)
|
||||||
return
|
|
||||||
}
|
}
|
||||||
broker.Start()
|
b.Start()
|
||||||
|
|
||||||
s := waitForSignal()
|
s := waitForSignal()
|
||||||
log.Infof("signal got: %v ,broker closed.", s)
|
log.Println("signal received, broker closed.", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
func waitForSignal() os.Signal {
|
func waitForSignal() os.Signal {
|
||||||
|
|||||||
@@ -0,0 +1,58 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/segmentio/fasthash/fnv1a"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WorkerPool struct {
|
||||||
|
maxWorkers int
|
||||||
|
taskQueue []chan func()
|
||||||
|
stoppedChan chan struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(maxWorkers int) *WorkerPool {
|
||||||
|
// There must be at least one worker.
|
||||||
|
if maxWorkers < 1 {
|
||||||
|
maxWorkers = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
// taskQueue is unbuffered since items are always removed immediately.
|
||||||
|
pool := &WorkerPool{
|
||||||
|
taskQueue: make([]chan func(), maxWorkers),
|
||||||
|
maxWorkers: maxWorkers,
|
||||||
|
stoppedChan: make(chan struct{}),
|
||||||
|
}
|
||||||
|
// Start the task dispatcher.
|
||||||
|
pool.dispatch()
|
||||||
|
|
||||||
|
return pool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WorkerPool) Submit(uid string, task func()) {
|
||||||
|
idx := fnv1a.HashString64(uid) % uint64(p.maxWorkers)
|
||||||
|
if task != nil {
|
||||||
|
p.taskQueue[idx] <- task
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *WorkerPool) dispatch() {
|
||||||
|
for i := 0; i < p.maxWorkers; i++ {
|
||||||
|
p.taskQueue[i] = make(chan func())
|
||||||
|
go startWorker(p.taskQueue[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func startWorker(taskChan chan func()) {
|
||||||
|
go func() {
|
||||||
|
var task func()
|
||||||
|
var ok bool
|
||||||
|
for {
|
||||||
|
task, ok = <-taskChan
|
||||||
|
if !ok {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// Execute the task.
|
||||||
|
task()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
+166
@@ -0,0 +1,166 @@
|
|||||||
|
package pool
|
||||||
|
|
||||||
|
// import "time"
|
||||||
|
|
||||||
|
// const (
|
||||||
|
// // This value is the size of the queue that workers register their
|
||||||
|
// // availability to the dispatcher. There may be hundreds of workers, but
|
||||||
|
// // only a small channel is needed to register some of the workers.
|
||||||
|
// readyQueueSize = 16
|
||||||
|
|
||||||
|
// // If worker pool receives no new work for this period of time, then stop
|
||||||
|
// // a worker goroutine.
|
||||||
|
// idleTimeoutSec = 5
|
||||||
|
// )
|
||||||
|
|
||||||
|
// type WorkerPool struct {
|
||||||
|
// maxWorkers int
|
||||||
|
// timeout time.Duration
|
||||||
|
// taskQueue chan func()
|
||||||
|
// readyWorkers chan chan func()
|
||||||
|
// stoppedChan chan struct{}
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func New(maxWorkers int) *WorkerPool {
|
||||||
|
// // There must be at least one worker.
|
||||||
|
// if maxWorkers < 1 {
|
||||||
|
// maxWorkers = 1
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // taskQueue is unbuffered since items are always removed immediately.
|
||||||
|
// pool := &WorkerPool{
|
||||||
|
// taskQueue: make(chan func()),
|
||||||
|
// maxWorkers: maxWorkers,
|
||||||
|
// readyWorkers: make(chan chan func(), readyQueueSize),
|
||||||
|
// timeout: time.Second * idleTimeoutSec,
|
||||||
|
// stoppedChan: make(chan struct{}),
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Start the task dispatcher.
|
||||||
|
// go pool.dispatch()
|
||||||
|
|
||||||
|
// return pool
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (p *WorkerPool) Stop() {
|
||||||
|
// if p.Stopped() {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// close(p.taskQueue)
|
||||||
|
// <-p.stoppedChan
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (p *WorkerPool) Stopped() bool {
|
||||||
|
// select {
|
||||||
|
// case <-p.stoppedChan:
|
||||||
|
// return true
|
||||||
|
// default:
|
||||||
|
// }
|
||||||
|
// return false
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (p *WorkerPool) Submit(task func()) {
|
||||||
|
// if task != nil {
|
||||||
|
// p.taskQueue <- task
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (p *WorkerPool) SubmitWait(task func()) {
|
||||||
|
// if task == nil {
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// doneChan := make(chan struct{})
|
||||||
|
// p.taskQueue <- func() {
|
||||||
|
// task()
|
||||||
|
// close(doneChan)
|
||||||
|
// }
|
||||||
|
// <-doneChan
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func (p *WorkerPool) dispatch() {
|
||||||
|
// defer close(p.stoppedChan)
|
||||||
|
// timeout := time.NewTimer(p.timeout)
|
||||||
|
// var workerCount int
|
||||||
|
// var task func()
|
||||||
|
// var ok bool
|
||||||
|
// var workerTaskChan chan func()
|
||||||
|
// startReady := make(chan chan func())
|
||||||
|
// Loop:
|
||||||
|
// for {
|
||||||
|
// timeout.Reset(p.timeout)
|
||||||
|
// select {
|
||||||
|
// case task, ok = <-p.taskQueue:
|
||||||
|
// if !ok {
|
||||||
|
// break Loop
|
||||||
|
// }
|
||||||
|
// // Got a task to do.
|
||||||
|
// select {
|
||||||
|
// case workerTaskChan = <-p.readyWorkers:
|
||||||
|
// // A worker is ready, so give task to worker.
|
||||||
|
// workerTaskChan <- task
|
||||||
|
// default:
|
||||||
|
// // No workers ready.
|
||||||
|
// // Create a new worker, if not at max.
|
||||||
|
// if workerCount < p.maxWorkers {
|
||||||
|
// workerCount++
|
||||||
|
// go func(t func()) {
|
||||||
|
// startWorker(startReady, p.readyWorkers)
|
||||||
|
// // Submit the task when the new worker.
|
||||||
|
// taskChan := <-startReady
|
||||||
|
// taskChan <- t
|
||||||
|
// }(task)
|
||||||
|
// } else {
|
||||||
|
// // Start a goroutine to submit the task when an existing
|
||||||
|
// // worker is ready.
|
||||||
|
// go func(t func()) {
|
||||||
|
// taskChan := <-p.readyWorkers
|
||||||
|
// taskChan <- t
|
||||||
|
// }(task)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// case <-timeout.C:
|
||||||
|
// // Timed out waiting for work to arrive. Kill a ready worker.
|
||||||
|
// if workerCount > 0 {
|
||||||
|
// select {
|
||||||
|
// case workerTaskChan = <-p.readyWorkers:
|
||||||
|
// // A worker is ready, so kill.
|
||||||
|
// close(workerTaskChan)
|
||||||
|
// workerCount--
|
||||||
|
// default:
|
||||||
|
// // No work, but no ready workers. All workers are busy.
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Stop all remaining workers as they become ready.
|
||||||
|
// for workerCount > 0 {
|
||||||
|
// workerTaskChan = <-p.readyWorkers
|
||||||
|
// close(workerTaskChan)
|
||||||
|
// workerCount--
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// func startWorker(startReady, readyWorkers chan chan func()) {
|
||||||
|
// go func() {
|
||||||
|
// taskChan := make(chan func())
|
||||||
|
// var task func()
|
||||||
|
// var ok bool
|
||||||
|
// // Register availability on starReady channel.
|
||||||
|
// startReady <- taskChan
|
||||||
|
// for {
|
||||||
|
// // Read task from dispatcher.
|
||||||
|
// task, ok = <-taskChan
|
||||||
|
// if !ok {
|
||||||
|
// // Dispatcher has told worker to stop.
|
||||||
|
// break
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Execute the task.
|
||||||
|
// task()
|
||||||
|
|
||||||
|
// // Register availability on readyWorkers channel.
|
||||||
|
// readyWorkers <- taskChan
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
// }
|
||||||
+20
@@ -0,0 +1,20 @@
|
|||||||
|
The MIT License (MIT)
|
||||||
|
|
||||||
|
Copyright (c) 2013 Stack Exchange
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||||
|
this software and associated documentation files (the "Software"), to deal in
|
||||||
|
the Software without restriction, including without limitation the rights to
|
||||||
|
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
||||||
|
the Software, and to permit persons to whom the Software is furnished to do so,
|
||||||
|
subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
|
||||||
|
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
|
||||||
|
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
|
||||||
|
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
|
||||||
|
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
||||||
+6
@@ -0,0 +1,6 @@
|
|||||||
|
wmi
|
||||||
|
===
|
||||||
|
|
||||||
|
Package wmi provides a WQL interface to Windows WMI.
|
||||||
|
|
||||||
|
Note: It interfaces with WMI on the local machine, therefore it only runs on Windows.
|
||||||
+260
@@ -0,0 +1,260 @@
|
|||||||
|
// +build windows
|
||||||
|
|
||||||
|
package wmi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/go-ole/go-ole"
|
||||||
|
"github.com/go-ole/go-ole/oleutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
// SWbemServices is used to access wmi. See https://msdn.microsoft.com/en-us/library/aa393719(v=vs.85).aspx
|
||||||
|
type SWbemServices struct {
|
||||||
|
//TODO: track namespace. Not sure if we can re connect to a different namespace using the same instance
|
||||||
|
cWMIClient *Client //This could also be an embedded struct, but then we would need to branch on Client vs SWbemServices in the Query method
|
||||||
|
sWbemLocatorIUnknown *ole.IUnknown
|
||||||
|
sWbemLocatorIDispatch *ole.IDispatch
|
||||||
|
queries chan *queryRequest
|
||||||
|
closeError chan error
|
||||||
|
lQueryorClose sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
type queryRequest struct {
|
||||||
|
query string
|
||||||
|
dst interface{}
|
||||||
|
args []interface{}
|
||||||
|
finished chan error
|
||||||
|
}
|
||||||
|
|
||||||
|
// InitializeSWbemServices will return a new SWbemServices object that can be used to query WMI
|
||||||
|
func InitializeSWbemServices(c *Client, connectServerArgs ...interface{}) (*SWbemServices, error) {
|
||||||
|
//fmt.Println("InitializeSWbemServices: Starting")
|
||||||
|
//TODO: implement connectServerArgs as optional argument for init with connectServer call
|
||||||
|
s := new(SWbemServices)
|
||||||
|
s.cWMIClient = c
|
||||||
|
s.queries = make(chan *queryRequest)
|
||||||
|
initError := make(chan error)
|
||||||
|
go s.process(initError)
|
||||||
|
|
||||||
|
err, ok := <-initError
|
||||||
|
if ok {
|
||||||
|
return nil, err //Send error to caller
|
||||||
|
}
|
||||||
|
//fmt.Println("InitializeSWbemServices: Finished")
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close will clear and release all of the SWbemServices resources
|
||||||
|
func (s *SWbemServices) Close() error {
|
||||||
|
s.lQueryorClose.Lock()
|
||||||
|
if s == nil || s.sWbemLocatorIDispatch == nil {
|
||||||
|
s.lQueryorClose.Unlock()
|
||||||
|
return fmt.Errorf("SWbemServices is not Initialized")
|
||||||
|
}
|
||||||
|
if s.queries == nil {
|
||||||
|
s.lQueryorClose.Unlock()
|
||||||
|
return fmt.Errorf("SWbemServices has been closed")
|
||||||
|
}
|
||||||
|
//fmt.Println("Close: sending close request")
|
||||||
|
var result error
|
||||||
|
ce := make(chan error)
|
||||||
|
s.closeError = ce //Race condition if multiple callers to close. May need to lock here
|
||||||
|
close(s.queries) //Tell background to shut things down
|
||||||
|
s.lQueryorClose.Unlock()
|
||||||
|
err, ok := <-ce
|
||||||
|
if ok {
|
||||||
|
result = err
|
||||||
|
}
|
||||||
|
//fmt.Println("Close: finished")
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SWbemServices) process(initError chan error) {
|
||||||
|
//fmt.Println("process: starting background thread initialization")
|
||||||
|
//All OLE/WMI calls must happen on the same initialized thead, so lock this goroutine
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
|
||||||
|
err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
|
||||||
|
if err != nil {
|
||||||
|
oleCode := err.(*ole.OleError).Code()
|
||||||
|
if oleCode != ole.S_OK && oleCode != S_FALSE {
|
||||||
|
initError <- fmt.Errorf("ole.CoInitializeEx error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer ole.CoUninitialize()
|
||||||
|
|
||||||
|
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
|
||||||
|
if err != nil {
|
||||||
|
initError <- fmt.Errorf("CreateObject SWbemLocator error: %v", err)
|
||||||
|
return
|
||||||
|
} else if unknown == nil {
|
||||||
|
initError <- ErrNilCreateObject
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer unknown.Release()
|
||||||
|
s.sWbemLocatorIUnknown = unknown
|
||||||
|
|
||||||
|
dispatch, err := s.sWbemLocatorIUnknown.QueryInterface(ole.IID_IDispatch)
|
||||||
|
if err != nil {
|
||||||
|
initError <- fmt.Errorf("SWbemLocator QueryInterface error: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer dispatch.Release()
|
||||||
|
s.sWbemLocatorIDispatch = dispatch
|
||||||
|
|
||||||
|
// we can't do the ConnectServer call outside the loop unless we find a way to track and re-init the connectServerArgs
|
||||||
|
//fmt.Println("process: initialized. closing initError")
|
||||||
|
close(initError)
|
||||||
|
//fmt.Println("process: waiting for queries")
|
||||||
|
for q := range s.queries {
|
||||||
|
//fmt.Printf("process: new query: len(query)=%d\n", len(q.query))
|
||||||
|
errQuery := s.queryBackground(q)
|
||||||
|
//fmt.Println("process: s.queryBackground finished")
|
||||||
|
if errQuery != nil {
|
||||||
|
q.finished <- errQuery
|
||||||
|
}
|
||||||
|
close(q.finished)
|
||||||
|
}
|
||||||
|
//fmt.Println("process: queries channel closed")
|
||||||
|
s.queries = nil //set channel to nil so we know it is closed
|
||||||
|
//TODO: I think the Release/Clear calls can panic if things are in a bad state.
|
||||||
|
//TODO: May need to recover from panics and send error to method caller instead.
|
||||||
|
close(s.closeError)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query runs the WQL query using a SWbemServices instance and appends the values to dst.
|
||||||
|
//
|
||||||
|
// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
|
||||||
|
// the query must have the same name in dst. Supported types are all signed and
|
||||||
|
// unsigned integers, time.Time, string, bool, or a pointer to one of those.
|
||||||
|
// Array types are not supported.
|
||||||
|
//
|
||||||
|
// By default, the local machine and default namespace are used. These can be
|
||||||
|
// changed using connectServerArgs. See
|
||||||
|
// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
|
||||||
|
func (s *SWbemServices) Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
|
||||||
|
s.lQueryorClose.Lock()
|
||||||
|
if s == nil || s.sWbemLocatorIDispatch == nil {
|
||||||
|
s.lQueryorClose.Unlock()
|
||||||
|
return fmt.Errorf("SWbemServices is not Initialized")
|
||||||
|
}
|
||||||
|
if s.queries == nil {
|
||||||
|
s.lQueryorClose.Unlock()
|
||||||
|
return fmt.Errorf("SWbemServices has been closed")
|
||||||
|
}
|
||||||
|
|
||||||
|
//fmt.Println("Query: Sending query request")
|
||||||
|
qr := queryRequest{
|
||||||
|
query: query,
|
||||||
|
dst: dst,
|
||||||
|
args: connectServerArgs,
|
||||||
|
finished: make(chan error),
|
||||||
|
}
|
||||||
|
s.queries <- &qr
|
||||||
|
s.lQueryorClose.Unlock()
|
||||||
|
err, ok := <-qr.finished
|
||||||
|
if ok {
|
||||||
|
//fmt.Println("Query: Finished with error")
|
||||||
|
return err //Send error to caller
|
||||||
|
}
|
||||||
|
//fmt.Println("Query: Finished")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *SWbemServices) queryBackground(q *queryRequest) error {
|
||||||
|
if s == nil || s.sWbemLocatorIDispatch == nil {
|
||||||
|
return fmt.Errorf("SWbemServices is not Initialized")
|
||||||
|
}
|
||||||
|
wmi := s.sWbemLocatorIDispatch //Should just rename in the code, but this will help as we break things apart
|
||||||
|
//fmt.Println("queryBackground: Starting")
|
||||||
|
|
||||||
|
dv := reflect.ValueOf(q.dst)
|
||||||
|
if dv.Kind() != reflect.Ptr || dv.IsNil() {
|
||||||
|
return ErrInvalidEntityType
|
||||||
|
}
|
||||||
|
dv = dv.Elem()
|
||||||
|
mat, elemType := checkMultiArg(dv)
|
||||||
|
if mat == multiArgTypeInvalid {
|
||||||
|
return ErrInvalidEntityType
|
||||||
|
}
|
||||||
|
|
||||||
|
// service is a SWbemServices
|
||||||
|
serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", q.args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
service := serviceRaw.ToIDispatch()
|
||||||
|
defer serviceRaw.Clear()
|
||||||
|
|
||||||
|
// result is a SWBemObjectSet
|
||||||
|
resultRaw, err := oleutil.CallMethod(service, "ExecQuery", q.query)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result := resultRaw.ToIDispatch()
|
||||||
|
defer resultRaw.Clear()
|
||||||
|
|
||||||
|
count, err := oleInt64(result, "Count")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
enumProperty, err := result.GetProperty("_NewEnum")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer enumProperty.Clear()
|
||||||
|
|
||||||
|
enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if enum == nil {
|
||||||
|
return fmt.Errorf("can't get IEnumVARIANT, enum is nil")
|
||||||
|
}
|
||||||
|
defer enum.Release()
|
||||||
|
|
||||||
|
// Initialize a slice with Count capacity
|
||||||
|
dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
|
||||||
|
|
||||||
|
var errFieldMismatch error
|
||||||
|
for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := func() error {
|
||||||
|
// item is a SWbemObject, but really a Win32_Process
|
||||||
|
item := itemRaw.ToIDispatch()
|
||||||
|
defer item.Release()
|
||||||
|
|
||||||
|
ev := reflect.New(elemType)
|
||||||
|
if err = s.cWMIClient.loadEntity(ev.Interface(), item); err != nil {
|
||||||
|
if _, ok := err.(*ErrFieldMismatch); ok {
|
||||||
|
// We continue loading entities even in the face of field mismatch errors.
|
||||||
|
// If we encounter any other error, that other error is returned. Otherwise,
|
||||||
|
// an ErrFieldMismatch is returned.
|
||||||
|
errFieldMismatch = err
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mat != multiArgTypeStructPtr {
|
||||||
|
ev = ev.Elem()
|
||||||
|
}
|
||||||
|
dv.Set(reflect.Append(dv, ev))
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
//fmt.Println("queryBackground: Finished")
|
||||||
|
return errFieldMismatch
|
||||||
|
}
|
||||||
+490
@@ -0,0 +1,490 @@
|
|||||||
|
// +build windows
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package wmi provides a WQL interface for WMI on Windows.
|
||||||
|
|
||||||
|
Example code to print names of running processes:
|
||||||
|
|
||||||
|
type Win32_Process struct {
|
||||||
|
Name string
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
var dst []Win32_Process
|
||||||
|
q := wmi.CreateQuery(&dst, "")
|
||||||
|
err := wmi.Query(q, &dst)
|
||||||
|
if err != nil {
|
||||||
|
log.Fatal(err)
|
||||||
|
}
|
||||||
|
for i, v := range dst {
|
||||||
|
println(i, v.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
package wmi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"runtime"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/go-ole/go-ole"
|
||||||
|
"github.com/go-ole/go-ole/oleutil"
|
||||||
|
)
|
||||||
|
|
||||||
|
var l = log.New(os.Stdout, "", log.LstdFlags)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrInvalidEntityType = errors.New("wmi: invalid entity type")
|
||||||
|
// ErrNilCreateObject is the error returned if CreateObject returns nil even
|
||||||
|
// if the error was nil.
|
||||||
|
ErrNilCreateObject = errors.New("wmi: create object returned nil")
|
||||||
|
lock sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
// S_FALSE is returned by CoInitializeEx if it was already called on this thread.
|
||||||
|
const S_FALSE = 0x00000001
|
||||||
|
|
||||||
|
// QueryNamespace invokes Query with the given namespace on the local machine.
|
||||||
|
func QueryNamespace(query string, dst interface{}, namespace string) error {
|
||||||
|
return Query(query, dst, nil, namespace)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query runs the WQL query and appends the values to dst.
|
||||||
|
//
|
||||||
|
// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
|
||||||
|
// the query must have the same name in dst. Supported types are all signed and
|
||||||
|
// unsigned integers, time.Time, string, bool, or a pointer to one of those.
|
||||||
|
// Array types are not supported.
|
||||||
|
//
|
||||||
|
// By default, the local machine and default namespace are used. These can be
|
||||||
|
// changed using connectServerArgs. See
|
||||||
|
// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
|
||||||
|
//
|
||||||
|
// Query is a wrapper around DefaultClient.Query.
|
||||||
|
func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
|
||||||
|
if DefaultClient.SWbemServicesClient == nil {
|
||||||
|
return DefaultClient.Query(query, dst, connectServerArgs...)
|
||||||
|
}
|
||||||
|
return DefaultClient.SWbemServicesClient.Query(query, dst, connectServerArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// A Client is an WMI query client.
|
||||||
|
//
|
||||||
|
// Its zero value (DefaultClient) is a usable client.
|
||||||
|
type Client struct {
|
||||||
|
// NonePtrZero specifies if nil values for fields which aren't pointers
|
||||||
|
// should be returned as the field types zero value.
|
||||||
|
//
|
||||||
|
// Setting this to true allows stucts without pointer fields to be used
|
||||||
|
// without the risk failure should a nil value returned from WMI.
|
||||||
|
NonePtrZero bool
|
||||||
|
|
||||||
|
// PtrNil specifies if nil values for pointer fields should be returned
|
||||||
|
// as nil.
|
||||||
|
//
|
||||||
|
// Setting this to true will set pointer fields to nil where WMI
|
||||||
|
// returned nil, otherwise the types zero value will be returned.
|
||||||
|
PtrNil bool
|
||||||
|
|
||||||
|
// AllowMissingFields specifies that struct fields not present in the
|
||||||
|
// query result should not result in an error.
|
||||||
|
//
|
||||||
|
// Setting this to true allows custom queries to be used with full
|
||||||
|
// struct definitions instead of having to define multiple structs.
|
||||||
|
AllowMissingFields bool
|
||||||
|
|
||||||
|
// SWbemServiceClient is an optional SWbemServices object that can be
|
||||||
|
// initialized and then reused across multiple queries. If it is null
|
||||||
|
// then the method will initialize a new temporary client each time.
|
||||||
|
SWbemServicesClient *SWbemServices
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultClient is the default Client and is used by Query, QueryNamespace
|
||||||
|
var DefaultClient = &Client{}
|
||||||
|
|
||||||
|
// Query runs the WQL query and appends the values to dst.
|
||||||
|
//
|
||||||
|
// dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
|
||||||
|
// the query must have the same name in dst. Supported types are all signed and
|
||||||
|
// unsigned integers, time.Time, string, bool, or a pointer to one of those.
|
||||||
|
// Array types are not supported.
|
||||||
|
//
|
||||||
|
// By default, the local machine and default namespace are used. These can be
|
||||||
|
// changed using connectServerArgs. See
|
||||||
|
// http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
|
||||||
|
func (c *Client) Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
|
||||||
|
dv := reflect.ValueOf(dst)
|
||||||
|
if dv.Kind() != reflect.Ptr || dv.IsNil() {
|
||||||
|
return ErrInvalidEntityType
|
||||||
|
}
|
||||||
|
dv = dv.Elem()
|
||||||
|
mat, elemType := checkMultiArg(dv)
|
||||||
|
if mat == multiArgTypeInvalid {
|
||||||
|
return ErrInvalidEntityType
|
||||||
|
}
|
||||||
|
|
||||||
|
lock.Lock()
|
||||||
|
defer lock.Unlock()
|
||||||
|
runtime.LockOSThread()
|
||||||
|
defer runtime.UnlockOSThread()
|
||||||
|
|
||||||
|
err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
|
||||||
|
if err != nil {
|
||||||
|
oleCode := err.(*ole.OleError).Code()
|
||||||
|
if oleCode != ole.S_OK && oleCode != S_FALSE {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
defer ole.CoUninitialize()
|
||||||
|
|
||||||
|
unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
} else if unknown == nil {
|
||||||
|
return ErrNilCreateObject
|
||||||
|
}
|
||||||
|
defer unknown.Release()
|
||||||
|
|
||||||
|
wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer wmi.Release()
|
||||||
|
|
||||||
|
// service is a SWbemServices
|
||||||
|
serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
service := serviceRaw.ToIDispatch()
|
||||||
|
defer serviceRaw.Clear()
|
||||||
|
|
||||||
|
// result is a SWBemObjectSet
|
||||||
|
resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
result := resultRaw.ToIDispatch()
|
||||||
|
defer resultRaw.Clear()
|
||||||
|
|
||||||
|
count, err := oleInt64(result, "Count")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
enumProperty, err := result.GetProperty("_NewEnum")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer enumProperty.Clear()
|
||||||
|
|
||||||
|
enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if enum == nil {
|
||||||
|
return fmt.Errorf("can't get IEnumVARIANT, enum is nil")
|
||||||
|
}
|
||||||
|
defer enum.Release()
|
||||||
|
|
||||||
|
// Initialize a slice with Count capacity
|
||||||
|
dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
|
||||||
|
|
||||||
|
var errFieldMismatch error
|
||||||
|
for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
err := func() error {
|
||||||
|
// item is a SWbemObject, but really a Win32_Process
|
||||||
|
item := itemRaw.ToIDispatch()
|
||||||
|
defer item.Release()
|
||||||
|
|
||||||
|
ev := reflect.New(elemType)
|
||||||
|
if err = c.loadEntity(ev.Interface(), item); err != nil {
|
||||||
|
if _, ok := err.(*ErrFieldMismatch); ok {
|
||||||
|
// We continue loading entities even in the face of field mismatch errors.
|
||||||
|
// If we encounter any other error, that other error is returned. Otherwise,
|
||||||
|
// an ErrFieldMismatch is returned.
|
||||||
|
errFieldMismatch = err
|
||||||
|
} else {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if mat != multiArgTypeStructPtr {
|
||||||
|
ev = ev.Elem()
|
||||||
|
}
|
||||||
|
dv.Set(reflect.Append(dv, ev))
|
||||||
|
return nil
|
||||||
|
}()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errFieldMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
// ErrFieldMismatch is returned when a field is to be loaded into a different
|
||||||
|
// type than the one it was stored from, or when a field is missing or
|
||||||
|
// unexported in the destination struct.
|
||||||
|
// StructType is the type of the struct pointed to by the destination argument.
|
||||||
|
type ErrFieldMismatch struct {
|
||||||
|
StructType reflect.Type
|
||||||
|
FieldName string
|
||||||
|
Reason string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ErrFieldMismatch) Error() string {
|
||||||
|
return fmt.Sprintf("wmi: cannot load field %q into a %q: %s",
|
||||||
|
e.FieldName, e.StructType, e.Reason)
|
||||||
|
}
|
||||||
|
|
||||||
|
var timeType = reflect.TypeOf(time.Time{})
|
||||||
|
|
||||||
|
// loadEntity loads a SWbemObject into a struct pointer.
|
||||||
|
func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
|
||||||
|
v := reflect.ValueOf(dst).Elem()
|
||||||
|
for i := 0; i < v.NumField(); i++ {
|
||||||
|
f := v.Field(i)
|
||||||
|
of := f
|
||||||
|
isPtr := f.Kind() == reflect.Ptr
|
||||||
|
if isPtr {
|
||||||
|
ptr := reflect.New(f.Type().Elem())
|
||||||
|
f.Set(ptr)
|
||||||
|
f = f.Elem()
|
||||||
|
}
|
||||||
|
n := v.Type().Field(i).Name
|
||||||
|
if !f.CanSet() {
|
||||||
|
return &ErrFieldMismatch{
|
||||||
|
StructType: of.Type(),
|
||||||
|
FieldName: n,
|
||||||
|
Reason: "CanSet() is false",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
prop, err := oleutil.GetProperty(src, n)
|
||||||
|
if err != nil {
|
||||||
|
if !c.AllowMissingFields {
|
||||||
|
errFieldMismatch = &ErrFieldMismatch{
|
||||||
|
StructType: of.Type(),
|
||||||
|
FieldName: n,
|
||||||
|
Reason: "no such struct field",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
defer prop.Clear()
|
||||||
|
|
||||||
|
if prop.VT == 0x1 { //VT_NULL
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
switch val := prop.Value().(type) {
|
||||||
|
case int8, int16, int32, int64, int:
|
||||||
|
v := reflect.ValueOf(val).Int()
|
||||||
|
switch f.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
f.SetInt(v)
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
f.SetUint(uint64(v))
|
||||||
|
default:
|
||||||
|
return &ErrFieldMismatch{
|
||||||
|
StructType: of.Type(),
|
||||||
|
FieldName: n,
|
||||||
|
Reason: "not an integer class",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case uint8, uint16, uint32, uint64:
|
||||||
|
v := reflect.ValueOf(val).Uint()
|
||||||
|
switch f.Kind() {
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
f.SetInt(int64(v))
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
f.SetUint(v)
|
||||||
|
default:
|
||||||
|
return &ErrFieldMismatch{
|
||||||
|
StructType: of.Type(),
|
||||||
|
FieldName: n,
|
||||||
|
Reason: "not an integer class",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
switch f.Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
f.SetString(val)
|
||||||
|
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
|
||||||
|
iv, err := strconv.ParseInt(val, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.SetInt(iv)
|
||||||
|
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||||
|
uv, err := strconv.ParseUint(val, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.SetUint(uv)
|
||||||
|
case reflect.Struct:
|
||||||
|
switch f.Type() {
|
||||||
|
case timeType:
|
||||||
|
if len(val) == 25 {
|
||||||
|
mins, err := strconv.Atoi(val[22:])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
|
||||||
|
}
|
||||||
|
t, err := time.Parse("20060102150405.000000-0700", val)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.Set(reflect.ValueOf(t))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case bool:
|
||||||
|
switch f.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
f.SetBool(val)
|
||||||
|
default:
|
||||||
|
return &ErrFieldMismatch{
|
||||||
|
StructType: of.Type(),
|
||||||
|
FieldName: n,
|
||||||
|
Reason: "not a bool",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case float32:
|
||||||
|
switch f.Kind() {
|
||||||
|
case reflect.Float32:
|
||||||
|
f.SetFloat(float64(val))
|
||||||
|
default:
|
||||||
|
return &ErrFieldMismatch{
|
||||||
|
StructType: of.Type(),
|
||||||
|
FieldName: n,
|
||||||
|
Reason: "not a Float32",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
if f.Kind() == reflect.Slice {
|
||||||
|
switch f.Type().Elem().Kind() {
|
||||||
|
case reflect.String:
|
||||||
|
safeArray := prop.ToArray()
|
||||||
|
if safeArray != nil {
|
||||||
|
arr := safeArray.ToValueArray()
|
||||||
|
fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
|
||||||
|
for i, v := range arr {
|
||||||
|
s := fArr.Index(i)
|
||||||
|
s.SetString(v.(string))
|
||||||
|
}
|
||||||
|
f.Set(fArr)
|
||||||
|
}
|
||||||
|
case reflect.Uint8:
|
||||||
|
safeArray := prop.ToArray()
|
||||||
|
if safeArray != nil {
|
||||||
|
arr := safeArray.ToValueArray()
|
||||||
|
fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
|
||||||
|
for i, v := range arr {
|
||||||
|
s := fArr.Index(i)
|
||||||
|
s.SetUint(reflect.ValueOf(v).Uint())
|
||||||
|
}
|
||||||
|
f.Set(fArr)
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
return &ErrFieldMismatch{
|
||||||
|
StructType: of.Type(),
|
||||||
|
FieldName: n,
|
||||||
|
Reason: fmt.Sprintf("unsupported slice type (%T)", val),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
typeof := reflect.TypeOf(val)
|
||||||
|
if typeof == nil && (isPtr || c.NonePtrZero) {
|
||||||
|
if (isPtr && c.PtrNil) || (!isPtr && c.NonePtrZero) {
|
||||||
|
of.Set(reflect.Zero(of.Type()))
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
return &ErrFieldMismatch{
|
||||||
|
StructType: of.Type(),
|
||||||
|
FieldName: n,
|
||||||
|
Reason: fmt.Sprintf("unsupported type (%T)", val),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return errFieldMismatch
|
||||||
|
}
|
||||||
|
|
||||||
|
type multiArgType int
|
||||||
|
|
||||||
|
const (
|
||||||
|
multiArgTypeInvalid multiArgType = iota
|
||||||
|
multiArgTypeStruct
|
||||||
|
multiArgTypeStructPtr
|
||||||
|
)
|
||||||
|
|
||||||
|
// checkMultiArg checks that v has type []S, []*S for some struct type S.
|
||||||
|
//
|
||||||
|
// It returns what category the slice's elements are, and the reflect.Type
|
||||||
|
// that represents S.
|
||||||
|
func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
|
||||||
|
if v.Kind() != reflect.Slice {
|
||||||
|
return multiArgTypeInvalid, nil
|
||||||
|
}
|
||||||
|
elemType = v.Type().Elem()
|
||||||
|
switch elemType.Kind() {
|
||||||
|
case reflect.Struct:
|
||||||
|
return multiArgTypeStruct, elemType
|
||||||
|
case reflect.Ptr:
|
||||||
|
elemType = elemType.Elem()
|
||||||
|
if elemType.Kind() == reflect.Struct {
|
||||||
|
return multiArgTypeStructPtr, elemType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return multiArgTypeInvalid, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
|
||||||
|
v, err := oleutil.GetProperty(item, prop)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer v.Clear()
|
||||||
|
|
||||||
|
i := int64(v.Val)
|
||||||
|
return i, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateQuery returns a WQL query string that queries all columns of src. where
|
||||||
|
// is an optional string that is appended to the query, to be used with WHERE
|
||||||
|
// clauses. In such a case, the "WHERE" string should appear at the beginning.
|
||||||
|
func CreateQuery(src interface{}, where string) string {
|
||||||
|
var b bytes.Buffer
|
||||||
|
b.WriteString("SELECT ")
|
||||||
|
s := reflect.Indirect(reflect.ValueOf(src))
|
||||||
|
t := s.Type()
|
||||||
|
if s.Kind() == reflect.Slice {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
if t.Kind() != reflect.Struct {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
var fields []string
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
fields = append(fields, t.Field(i).Name)
|
||||||
|
}
|
||||||
|
b.WriteString(strings.Join(fields, ", "))
|
||||||
|
b.WriteString(" FROM ")
|
||||||
|
b.WriteString(t.Name())
|
||||||
|
b.WriteString(" " + where)
|
||||||
|
return b.String()
|
||||||
|
}
|
||||||
+10
@@ -0,0 +1,10 @@
|
|||||||
|
language: go
|
||||||
|
go:
|
||||||
|
- 1.0.3
|
||||||
|
- 1.1.2
|
||||||
|
- 1.2
|
||||||
|
- tip
|
||||||
|
install:
|
||||||
|
- go get github.com/bmizerany/assert
|
||||||
|
notifications:
|
||||||
|
email: false
|
||||||
+17
@@ -0,0 +1,17 @@
|
|||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in
|
||||||
|
all copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
|
THE SOFTWARE.
|
||||||
+13
@@ -0,0 +1,13 @@
|
|||||||
|
### go-simplejson
|
||||||
|
|
||||||
|
a Go package to interact with arbitrary JSON
|
||||||
|
|
||||||
|
[](http://travis-ci.org/bitly/go-simplejson)
|
||||||
|
|
||||||
|
### Importing
|
||||||
|
|
||||||
|
import github.com/bitly/go-simplejson
|
||||||
|
|
||||||
|
### Documentation
|
||||||
|
|
||||||
|
Visit the docs on [gopkgdoc](http://godoc.org/github.com/bitly/go-simplejson)
|
||||||
+446
@@ -0,0 +1,446 @@
|
|||||||
|
package simplejson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// returns the current implementation version
|
||||||
|
func Version() string {
|
||||||
|
return "0.5.0"
|
||||||
|
}
|
||||||
|
|
||||||
|
type Json struct {
|
||||||
|
data interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewJson returns a pointer to a new `Json` object
|
||||||
|
// after unmarshaling `body` bytes
|
||||||
|
func NewJson(body []byte) (*Json, error) {
|
||||||
|
j := new(Json)
|
||||||
|
err := j.UnmarshalJSON(body)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return j, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// New returns a pointer to a new, empty `Json` object
|
||||||
|
func New() *Json {
|
||||||
|
return &Json{
|
||||||
|
data: make(map[string]interface{}),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Interface returns the underlying data
|
||||||
|
func (j *Json) Interface() interface{} {
|
||||||
|
return j.data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Encode returns its marshaled data as `[]byte`
|
||||||
|
func (j *Json) Encode() ([]byte, error) {
|
||||||
|
return j.MarshalJSON()
|
||||||
|
}
|
||||||
|
|
||||||
|
// EncodePretty returns its marshaled data as `[]byte` with indentation
|
||||||
|
func (j *Json) EncodePretty() ([]byte, error) {
|
||||||
|
return json.MarshalIndent(&j.data, "", " ")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the json.Marshaler interface.
|
||||||
|
func (j *Json) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(&j.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set modifies `Json` map by `key` and `value`
|
||||||
|
// Useful for changing single key/value in a `Json` object easily.
|
||||||
|
func (j *Json) Set(key string, val interface{}) {
|
||||||
|
m, err := j.Map()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m[key] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetPath modifies `Json`, recursively checking/creating map keys for the supplied path,
|
||||||
|
// and then finally writing in the value
|
||||||
|
func (j *Json) SetPath(branch []string, val interface{}) {
|
||||||
|
if len(branch) == 0 {
|
||||||
|
j.data = val
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// in order to insert our branch, we need map[string]interface{}
|
||||||
|
if _, ok := (j.data).(map[string]interface{}); !ok {
|
||||||
|
// have to replace with something suitable
|
||||||
|
j.data = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
curr := j.data.(map[string]interface{})
|
||||||
|
|
||||||
|
for i := 0; i < len(branch)-1; i++ {
|
||||||
|
b := branch[i]
|
||||||
|
// key exists?
|
||||||
|
if _, ok := curr[b]; !ok {
|
||||||
|
n := make(map[string]interface{})
|
||||||
|
curr[b] = n
|
||||||
|
curr = n
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// make sure the value is the right sort of thing
|
||||||
|
if _, ok := curr[b].(map[string]interface{}); !ok {
|
||||||
|
// have to replace with something suitable
|
||||||
|
n := make(map[string]interface{})
|
||||||
|
curr[b] = n
|
||||||
|
}
|
||||||
|
|
||||||
|
curr = curr[b].(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// add remaining k/v
|
||||||
|
curr[branch[len(branch)-1]] = val
|
||||||
|
}
|
||||||
|
|
||||||
|
// Del modifies `Json` map by deleting `key` if it is present.
|
||||||
|
func (j *Json) Del(key string) {
|
||||||
|
m, err := j.Map()
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
delete(m, key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get returns a pointer to a new `Json` object
|
||||||
|
// for `key` in its `map` representation
|
||||||
|
//
|
||||||
|
// useful for chaining operations (to traverse a nested JSON):
|
||||||
|
// js.Get("top_level").Get("dict").Get("value").Int()
|
||||||
|
func (j *Json) Get(key string) *Json {
|
||||||
|
m, err := j.Map()
|
||||||
|
if err == nil {
|
||||||
|
if val, ok := m[key]; ok {
|
||||||
|
return &Json{val}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Json{nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetPath searches for the item as specified by the branch
|
||||||
|
// without the need to deep dive using Get()'s.
|
||||||
|
//
|
||||||
|
// js.GetPath("top_level", "dict")
|
||||||
|
func (j *Json) GetPath(branch ...string) *Json {
|
||||||
|
jin := j
|
||||||
|
for _, p := range branch {
|
||||||
|
jin = jin.Get(p)
|
||||||
|
}
|
||||||
|
return jin
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetIndex returns a pointer to a new `Json` object
|
||||||
|
// for `index` in its `array` representation
|
||||||
|
//
|
||||||
|
// this is the analog to Get when accessing elements of
|
||||||
|
// a json array instead of a json object:
|
||||||
|
// js.Get("top_level").Get("array").GetIndex(1).Get("key").Int()
|
||||||
|
func (j *Json) GetIndex(index int) *Json {
|
||||||
|
a, err := j.Array()
|
||||||
|
if err == nil {
|
||||||
|
if len(a) > index {
|
||||||
|
return &Json{a[index]}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &Json{nil}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CheckGet returns a pointer to a new `Json` object and
|
||||||
|
// a `bool` identifying success or failure
|
||||||
|
//
|
||||||
|
// useful for chained operations when success is important:
|
||||||
|
// if data, ok := js.Get("top_level").CheckGet("inner"); ok {
|
||||||
|
// log.Println(data)
|
||||||
|
// }
|
||||||
|
func (j *Json) CheckGet(key string) (*Json, bool) {
|
||||||
|
m, err := j.Map()
|
||||||
|
if err == nil {
|
||||||
|
if val, ok := m[key]; ok {
|
||||||
|
return &Json{val}, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map type asserts to `map`
|
||||||
|
func (j *Json) Map() (map[string]interface{}, error) {
|
||||||
|
if m, ok := (j.data).(map[string]interface{}); ok {
|
||||||
|
return m, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("type assertion to map[string]interface{} failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Array type asserts to an `array`
|
||||||
|
func (j *Json) Array() ([]interface{}, error) {
|
||||||
|
if a, ok := (j.data).([]interface{}); ok {
|
||||||
|
return a, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("type assertion to []interface{} failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bool type asserts to `bool`
|
||||||
|
func (j *Json) Bool() (bool, error) {
|
||||||
|
if s, ok := (j.data).(bool); ok {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
return false, errors.New("type assertion to bool failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// String type asserts to `string`
|
||||||
|
func (j *Json) String() (string, error) {
|
||||||
|
if s, ok := (j.data).(string); ok {
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
return "", errors.New("type assertion to string failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Bytes type asserts to `[]byte`
|
||||||
|
func (j *Json) Bytes() ([]byte, error) {
|
||||||
|
if s, ok := (j.data).(string); ok {
|
||||||
|
return []byte(s), nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("type assertion to []byte failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
// StringArray type asserts to an `array` of `string`
|
||||||
|
func (j *Json) StringArray() ([]string, error) {
|
||||||
|
arr, err := j.Array()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
retArr := make([]string, 0, len(arr))
|
||||||
|
for _, a := range arr {
|
||||||
|
if a == nil {
|
||||||
|
retArr = append(retArr, "")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
s, ok := a.(string)
|
||||||
|
if !ok {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
retArr = append(retArr, s)
|
||||||
|
}
|
||||||
|
return retArr, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustArray guarantees the return of a `[]interface{}` (with optional default)
|
||||||
|
//
|
||||||
|
// useful when you want to interate over array values in a succinct manner:
|
||||||
|
// for i, v := range js.Get("results").MustArray() {
|
||||||
|
// fmt.Println(i, v)
|
||||||
|
// }
|
||||||
|
func (j *Json) MustArray(args ...[]interface{}) []interface{} {
|
||||||
|
var def []interface{}
|
||||||
|
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
def = args[0]
|
||||||
|
default:
|
||||||
|
log.Panicf("MustArray() received too many arguments %d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := j.Array()
|
||||||
|
if err == nil {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustMap guarantees the return of a `map[string]interface{}` (with optional default)
|
||||||
|
//
|
||||||
|
// useful when you want to interate over map values in a succinct manner:
|
||||||
|
// for k, v := range js.Get("dictionary").MustMap() {
|
||||||
|
// fmt.Println(k, v)
|
||||||
|
// }
|
||||||
|
func (j *Json) MustMap(args ...map[string]interface{}) map[string]interface{} {
|
||||||
|
var def map[string]interface{}
|
||||||
|
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
def = args[0]
|
||||||
|
default:
|
||||||
|
log.Panicf("MustMap() received too many arguments %d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := j.Map()
|
||||||
|
if err == nil {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustString guarantees the return of a `string` (with optional default)
|
||||||
|
//
|
||||||
|
// useful when you explicitly want a `string` in a single value return context:
|
||||||
|
// myFunc(js.Get("param1").MustString(), js.Get("optional_param").MustString("my_default"))
|
||||||
|
func (j *Json) MustString(args ...string) string {
|
||||||
|
var def string
|
||||||
|
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
def = args[0]
|
||||||
|
default:
|
||||||
|
log.Panicf("MustString() received too many arguments %d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
s, err := j.String()
|
||||||
|
if err == nil {
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustStringArray guarantees the return of a `[]string` (with optional default)
|
||||||
|
//
|
||||||
|
// useful when you want to interate over array values in a succinct manner:
|
||||||
|
// for i, s := range js.Get("results").MustStringArray() {
|
||||||
|
// fmt.Println(i, s)
|
||||||
|
// }
|
||||||
|
func (j *Json) MustStringArray(args ...[]string) []string {
|
||||||
|
var def []string
|
||||||
|
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
def = args[0]
|
||||||
|
default:
|
||||||
|
log.Panicf("MustStringArray() received too many arguments %d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
a, err := j.StringArray()
|
||||||
|
if err == nil {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustInt guarantees the return of an `int` (with optional default)
|
||||||
|
//
|
||||||
|
// useful when you explicitly want an `int` in a single value return context:
|
||||||
|
// myFunc(js.Get("param1").MustInt(), js.Get("optional_param").MustInt(5150))
|
||||||
|
func (j *Json) MustInt(args ...int) int {
|
||||||
|
var def int
|
||||||
|
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
def = args[0]
|
||||||
|
default:
|
||||||
|
log.Panicf("MustInt() received too many arguments %d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
i, err := j.Int()
|
||||||
|
if err == nil {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustFloat64 guarantees the return of a `float64` (with optional default)
|
||||||
|
//
|
||||||
|
// useful when you explicitly want a `float64` in a single value return context:
|
||||||
|
// myFunc(js.Get("param1").MustFloat64(), js.Get("optional_param").MustFloat64(5.150))
|
||||||
|
func (j *Json) MustFloat64(args ...float64) float64 {
|
||||||
|
var def float64
|
||||||
|
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
def = args[0]
|
||||||
|
default:
|
||||||
|
log.Panicf("MustFloat64() received too many arguments %d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := j.Float64()
|
||||||
|
if err == nil {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustBool guarantees the return of a `bool` (with optional default)
|
||||||
|
//
|
||||||
|
// useful when you explicitly want a `bool` in a single value return context:
|
||||||
|
// myFunc(js.Get("param1").MustBool(), js.Get("optional_param").MustBool(true))
|
||||||
|
func (j *Json) MustBool(args ...bool) bool {
|
||||||
|
var def bool
|
||||||
|
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
def = args[0]
|
||||||
|
default:
|
||||||
|
log.Panicf("MustBool() received too many arguments %d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
b, err := j.Bool()
|
||||||
|
if err == nil {
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustInt64 guarantees the return of an `int64` (with optional default)
|
||||||
|
//
|
||||||
|
// useful when you explicitly want an `int64` in a single value return context:
|
||||||
|
// myFunc(js.Get("param1").MustInt64(), js.Get("optional_param").MustInt64(5150))
|
||||||
|
func (j *Json) MustInt64(args ...int64) int64 {
|
||||||
|
var def int64
|
||||||
|
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
def = args[0]
|
||||||
|
default:
|
||||||
|
log.Panicf("MustInt64() received too many arguments %d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
i, err := j.Int64()
|
||||||
|
if err == nil {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
|
|
||||||
|
// MustUInt64 guarantees the return of an `uint64` (with optional default)
|
||||||
|
//
|
||||||
|
// useful when you explicitly want an `uint64` in a single value return context:
|
||||||
|
// myFunc(js.Get("param1").MustUint64(), js.Get("optional_param").MustUint64(5150))
|
||||||
|
func (j *Json) MustUint64(args ...uint64) uint64 {
|
||||||
|
var def uint64
|
||||||
|
|
||||||
|
switch len(args) {
|
||||||
|
case 0:
|
||||||
|
case 1:
|
||||||
|
def = args[0]
|
||||||
|
default:
|
||||||
|
log.Panicf("MustUint64() received too many arguments %d", len(args))
|
||||||
|
}
|
||||||
|
|
||||||
|
i, err := j.Uint64()
|
||||||
|
if err == nil {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
return def
|
||||||
|
}
|
||||||
+75
@@ -0,0 +1,75 @@
|
|||||||
|
// +build !go1.1
|
||||||
|
|
||||||
|
package simplejson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewFromReader returns a *Json by decoding from an io.Reader
|
||||||
|
func NewFromReader(r io.Reader) (*Json, error) {
|
||||||
|
j := new(Json)
|
||||||
|
dec := json.NewDecoder(r)
|
||||||
|
err := dec.Decode(&j.data)
|
||||||
|
return j, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implements the json.Unmarshaler interface.
|
||||||
|
func (j *Json) UnmarshalJSON(p []byte) error {
|
||||||
|
return json.Unmarshal(p, &j.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float64 coerces into a float64
|
||||||
|
func (j *Json) Float64() (float64, error) {
|
||||||
|
switch j.data.(type) {
|
||||||
|
case float32, float64:
|
||||||
|
return reflect.ValueOf(j.data).Float(), nil
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
return float64(reflect.ValueOf(j.data).Int()), nil
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
return float64(reflect.ValueOf(j.data).Uint()), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("invalid value type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int coerces into an int
|
||||||
|
func (j *Json) Int() (int, error) {
|
||||||
|
switch j.data.(type) {
|
||||||
|
case float32, float64:
|
||||||
|
return int(reflect.ValueOf(j.data).Float()), nil
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
return int(reflect.ValueOf(j.data).Int()), nil
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
return int(reflect.ValueOf(j.data).Uint()), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("invalid value type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int64 coerces into an int64
|
||||||
|
func (j *Json) Int64() (int64, error) {
|
||||||
|
switch j.data.(type) {
|
||||||
|
case float32, float64:
|
||||||
|
return int64(reflect.ValueOf(j.data).Float()), nil
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
return reflect.ValueOf(j.data).Int(), nil
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
return int64(reflect.ValueOf(j.data).Uint()), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("invalid value type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uint64 coerces into an uint64
|
||||||
|
func (j *Json) Uint64() (uint64, error) {
|
||||||
|
switch j.data.(type) {
|
||||||
|
case float32, float64:
|
||||||
|
return uint64(reflect.ValueOf(j.data).Float()), nil
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
return uint64(reflect.ValueOf(j.data).Int()), nil
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
return reflect.ValueOf(j.data).Uint(), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("invalid value type")
|
||||||
|
}
|
||||||
+89
@@ -0,0 +1,89 @@
|
|||||||
|
// +build go1.1
|
||||||
|
|
||||||
|
package simplejson
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Implements the json.Unmarshaler interface.
|
||||||
|
func (j *Json) UnmarshalJSON(p []byte) error {
|
||||||
|
dec := json.NewDecoder(bytes.NewBuffer(p))
|
||||||
|
dec.UseNumber()
|
||||||
|
return dec.Decode(&j.data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFromReader returns a *Json by decoding from an io.Reader
|
||||||
|
func NewFromReader(r io.Reader) (*Json, error) {
|
||||||
|
j := new(Json)
|
||||||
|
dec := json.NewDecoder(r)
|
||||||
|
dec.UseNumber()
|
||||||
|
err := dec.Decode(&j.data)
|
||||||
|
return j, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Float64 coerces into a float64
|
||||||
|
func (j *Json) Float64() (float64, error) {
|
||||||
|
switch j.data.(type) {
|
||||||
|
case json.Number:
|
||||||
|
return j.data.(json.Number).Float64()
|
||||||
|
case float32, float64:
|
||||||
|
return reflect.ValueOf(j.data).Float(), nil
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
return float64(reflect.ValueOf(j.data).Int()), nil
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
return float64(reflect.ValueOf(j.data).Uint()), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("invalid value type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int coerces into an int
|
||||||
|
func (j *Json) Int() (int, error) {
|
||||||
|
switch j.data.(type) {
|
||||||
|
case json.Number:
|
||||||
|
i, err := j.data.(json.Number).Int64()
|
||||||
|
return int(i), err
|
||||||
|
case float32, float64:
|
||||||
|
return int(reflect.ValueOf(j.data).Float()), nil
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
return int(reflect.ValueOf(j.data).Int()), nil
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
return int(reflect.ValueOf(j.data).Uint()), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("invalid value type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Int64 coerces into an int64
|
||||||
|
func (j *Json) Int64() (int64, error) {
|
||||||
|
switch j.data.(type) {
|
||||||
|
case json.Number:
|
||||||
|
return j.data.(json.Number).Int64()
|
||||||
|
case float32, float64:
|
||||||
|
return int64(reflect.ValueOf(j.data).Float()), nil
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
return reflect.ValueOf(j.data).Int(), nil
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
return int64(reflect.ValueOf(j.data).Uint()), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("invalid value type")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uint64 coerces into an uint64
|
||||||
|
func (j *Json) Uint64() (uint64, error) {
|
||||||
|
switch j.data.(type) {
|
||||||
|
case json.Number:
|
||||||
|
return strconv.ParseUint(j.data.(json.Number).String(), 10, 64)
|
||||||
|
case float32, float64:
|
||||||
|
return uint64(reflect.ValueOf(j.data).Float()), nil
|
||||||
|
case int, int8, int16, int32, int64:
|
||||||
|
return uint64(reflect.ValueOf(j.data).Int()), nil
|
||||||
|
case uint, uint8, uint16, uint32, uint64:
|
||||||
|
return reflect.ValueOf(j.data).Uint(), nil
|
||||||
|
}
|
||||||
|
return 0, errors.New("invalid value type")
|
||||||
|
}
|
||||||
+15
@@ -0,0 +1,15 @@
|
|||||||
|
ISC License
|
||||||
|
|
||||||
|
Copyright (c) 2012-2016 Dave Collins <dave@davec.name>
|
||||||
|
|
||||||
|
Permission to use, copy, modify, and distribute this software for any
|
||||||
|
purpose with or without fee is hereby granted, provided that the above
|
||||||
|
copyright notice and this permission notice appear in all copies.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
+152
@@ -0,0 +1,152 @@
|
|||||||
|
// Copyright (c) 2015-2016 Dave Collins <dave@davec.name>
|
||||||
|
//
|
||||||
|
// Permission to use, copy, modify, and distribute this software for any
|
||||||
|
// purpose with or without fee is hereby granted, provided that the above
|
||||||
|
// copyright notice and this permission notice appear in all copies.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||||
|
// when the code is not running on Google App Engine, compiled by GopherJS, and
|
||||||
|
// "-tags safe" is not added to the go build command line. The "disableunsafe"
|
||||||
|
// tag is deprecated and thus should not be used.
|
||||||
|
// +build !js,!appengine,!safe,!disableunsafe
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UnsafeDisabled is a build-time constant which specifies whether or
|
||||||
|
// not access to the unsafe package is available.
|
||||||
|
UnsafeDisabled = false
|
||||||
|
|
||||||
|
// ptrSize is the size of a pointer on the current arch.
|
||||||
|
ptrSize = unsafe.Sizeof((*byte)(nil))
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// offsetPtr, offsetScalar, and offsetFlag are the offsets for the
|
||||||
|
// internal reflect.Value fields. These values are valid before golang
|
||||||
|
// commit ecccf07e7f9d which changed the format. The are also valid
|
||||||
|
// after commit 82f48826c6c7 which changed the format again to mirror
|
||||||
|
// the original format. Code in the init function updates these offsets
|
||||||
|
// as necessary.
|
||||||
|
offsetPtr = uintptr(ptrSize)
|
||||||
|
offsetScalar = uintptr(0)
|
||||||
|
offsetFlag = uintptr(ptrSize * 2)
|
||||||
|
|
||||||
|
// flagKindWidth and flagKindShift indicate various bits that the
|
||||||
|
// reflect package uses internally to track kind information.
|
||||||
|
//
|
||||||
|
// flagRO indicates whether or not the value field of a reflect.Value is
|
||||||
|
// read-only.
|
||||||
|
//
|
||||||
|
// flagIndir indicates whether the value field of a reflect.Value is
|
||||||
|
// the actual data or a pointer to the data.
|
||||||
|
//
|
||||||
|
// These values are valid before golang commit 90a7c3c86944 which
|
||||||
|
// changed their positions. Code in the init function updates these
|
||||||
|
// flags as necessary.
|
||||||
|
flagKindWidth = uintptr(5)
|
||||||
|
flagKindShift = uintptr(flagKindWidth - 1)
|
||||||
|
flagRO = uintptr(1 << 0)
|
||||||
|
flagIndir = uintptr(1 << 1)
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Older versions of reflect.Value stored small integers directly in the
|
||||||
|
// ptr field (which is named val in the older versions). Versions
|
||||||
|
// between commits ecccf07e7f9d and 82f48826c6c7 added a new field named
|
||||||
|
// scalar for this purpose which unfortunately came before the flag
|
||||||
|
// field, so the offset of the flag field is different for those
|
||||||
|
// versions.
|
||||||
|
//
|
||||||
|
// This code constructs a new reflect.Value from a known small integer
|
||||||
|
// and checks if the size of the reflect.Value struct indicates it has
|
||||||
|
// the scalar field. When it does, the offsets are updated accordingly.
|
||||||
|
vv := reflect.ValueOf(0xf00)
|
||||||
|
if unsafe.Sizeof(vv) == (ptrSize * 4) {
|
||||||
|
offsetScalar = ptrSize * 2
|
||||||
|
offsetFlag = ptrSize * 3
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit 90a7c3c86944 changed the flag positions such that the low
|
||||||
|
// order bits are the kind. This code extracts the kind from the flags
|
||||||
|
// field and ensures it's the correct type. When it's not, the flag
|
||||||
|
// order has been changed to the newer format, so the flags are updated
|
||||||
|
// accordingly.
|
||||||
|
upf := unsafe.Pointer(uintptr(unsafe.Pointer(&vv)) + offsetFlag)
|
||||||
|
upfv := *(*uintptr)(upf)
|
||||||
|
flagKindMask := uintptr((1<<flagKindWidth - 1) << flagKindShift)
|
||||||
|
if (upfv&flagKindMask)>>flagKindShift != uintptr(reflect.Int) {
|
||||||
|
flagKindShift = 0
|
||||||
|
flagRO = 1 << 5
|
||||||
|
flagIndir = 1 << 6
|
||||||
|
|
||||||
|
// Commit adf9b30e5594 modified the flags to separate the
|
||||||
|
// flagRO flag into two bits which specifies whether or not the
|
||||||
|
// field is embedded. This causes flagIndir to move over a bit
|
||||||
|
// and means that flagRO is the combination of either of the
|
||||||
|
// original flagRO bit and the new bit.
|
||||||
|
//
|
||||||
|
// This code detects the change by extracting what used to be
|
||||||
|
// the indirect bit to ensure it's set. When it's not, the flag
|
||||||
|
// order has been changed to the newer format, so the flags are
|
||||||
|
// updated accordingly.
|
||||||
|
if upfv&flagIndir == 0 {
|
||||||
|
flagRO = 3 << 5
|
||||||
|
flagIndir = 1 << 7
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// unsafeReflectValue converts the passed reflect.Value into a one that bypasses
|
||||||
|
// the typical safety restrictions preventing access to unaddressable and
|
||||||
|
// unexported data. It works by digging the raw pointer to the underlying
|
||||||
|
// value out of the protected value and generating a new unprotected (unsafe)
|
||||||
|
// reflect.Value to it.
|
||||||
|
//
|
||||||
|
// This allows us to check for implementations of the Stringer and error
|
||||||
|
// interfaces to be used for pretty printing ordinarily unaddressable and
|
||||||
|
// inaccessible values such as unexported struct fields.
|
||||||
|
func unsafeReflectValue(v reflect.Value) (rv reflect.Value) {
|
||||||
|
indirects := 1
|
||||||
|
vt := v.Type()
|
||||||
|
upv := unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetPtr)
|
||||||
|
rvf := *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&v)) + offsetFlag))
|
||||||
|
if rvf&flagIndir != 0 {
|
||||||
|
vt = reflect.PtrTo(v.Type())
|
||||||
|
indirects++
|
||||||
|
} else if offsetScalar != 0 {
|
||||||
|
// The value is in the scalar field when it's not one of the
|
||||||
|
// reference types.
|
||||||
|
switch vt.Kind() {
|
||||||
|
case reflect.Uintptr:
|
||||||
|
case reflect.Chan:
|
||||||
|
case reflect.Func:
|
||||||
|
case reflect.Map:
|
||||||
|
case reflect.Ptr:
|
||||||
|
case reflect.UnsafePointer:
|
||||||
|
default:
|
||||||
|
upv = unsafe.Pointer(uintptr(unsafe.Pointer(&v)) +
|
||||||
|
offsetScalar)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pv := reflect.NewAt(vt, upv)
|
||||||
|
rv = pv
|
||||||
|
for i := 0; i < indirects; i++ {
|
||||||
|
rv = rv.Elem()
|
||||||
|
}
|
||||||
|
return rv
|
||||||
|
}
|
||||||
+38
@@ -0,0 +1,38 @@
|
|||||||
|
// Copyright (c) 2015-2016 Dave Collins <dave@davec.name>
|
||||||
|
//
|
||||||
|
// Permission to use, copy, modify, and distribute this software for any
|
||||||
|
// purpose with or without fee is hereby granted, provided that the above
|
||||||
|
// copyright notice and this permission notice appear in all copies.
|
||||||
|
//
|
||||||
|
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
|
||||||
|
// NOTE: Due to the following build constraints, this file will only be compiled
|
||||||
|
// when the code is running on Google App Engine, compiled by GopherJS, or
|
||||||
|
// "-tags safe" is added to the go build command line. The "disableunsafe"
|
||||||
|
// tag is deprecated and thus should not be used.
|
||||||
|
// +build js appengine safe disableunsafe
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import "reflect"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// UnsafeDisabled is a build-time constant which specifies whether or
|
||||||
|
// not access to the unsafe package is available.
|
||||||
|
UnsafeDisabled = true
|
||||||
|
)
|
||||||
|
|
||||||
|
// unsafeReflectValue typically converts the passed reflect.Value into a one
|
||||||
|
// that bypasses the typical safety restrictions preventing access to
|
||||||
|
// unaddressable and unexported data. However, doing this relies on access to
|
||||||
|
// the unsafe package. This is a stub version which simply returns the passed
|
||||||
|
// reflect.Value when the unsafe package is not available.
|
||||||
|
func unsafeReflectValue(v reflect.Value) reflect.Value {
|
||||||
|
return v
|
||||||
|
}
|
||||||
+341
@@ -0,0 +1,341 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strconv"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Some constants in the form of bytes to avoid string overhead. This mirrors
|
||||||
|
// the technique used in the fmt package.
|
||||||
|
var (
|
||||||
|
panicBytes = []byte("(PANIC=")
|
||||||
|
plusBytes = []byte("+")
|
||||||
|
iBytes = []byte("i")
|
||||||
|
trueBytes = []byte("true")
|
||||||
|
falseBytes = []byte("false")
|
||||||
|
interfaceBytes = []byte("(interface {})")
|
||||||
|
commaNewlineBytes = []byte(",\n")
|
||||||
|
newlineBytes = []byte("\n")
|
||||||
|
openBraceBytes = []byte("{")
|
||||||
|
openBraceNewlineBytes = []byte("{\n")
|
||||||
|
closeBraceBytes = []byte("}")
|
||||||
|
asteriskBytes = []byte("*")
|
||||||
|
colonBytes = []byte(":")
|
||||||
|
colonSpaceBytes = []byte(": ")
|
||||||
|
openParenBytes = []byte("(")
|
||||||
|
closeParenBytes = []byte(")")
|
||||||
|
spaceBytes = []byte(" ")
|
||||||
|
pointerChainBytes = []byte("->")
|
||||||
|
nilAngleBytes = []byte("<nil>")
|
||||||
|
maxNewlineBytes = []byte("<max depth reached>\n")
|
||||||
|
maxShortBytes = []byte("<max>")
|
||||||
|
circularBytes = []byte("<already shown>")
|
||||||
|
circularShortBytes = []byte("<shown>")
|
||||||
|
invalidAngleBytes = []byte("<invalid>")
|
||||||
|
openBracketBytes = []byte("[")
|
||||||
|
closeBracketBytes = []byte("]")
|
||||||
|
percentBytes = []byte("%")
|
||||||
|
precisionBytes = []byte(".")
|
||||||
|
openAngleBytes = []byte("<")
|
||||||
|
closeAngleBytes = []byte(">")
|
||||||
|
openMapBytes = []byte("map[")
|
||||||
|
closeMapBytes = []byte("]")
|
||||||
|
lenEqualsBytes = []byte("len=")
|
||||||
|
capEqualsBytes = []byte("cap=")
|
||||||
|
)
|
||||||
|
|
||||||
|
// hexDigits is used to map a decimal value to a hex digit.
|
||||||
|
var hexDigits = "0123456789abcdef"
|
||||||
|
|
||||||
|
// catchPanic handles any panics that might occur during the handleMethods
|
||||||
|
// calls.
|
||||||
|
func catchPanic(w io.Writer, v reflect.Value) {
|
||||||
|
if err := recover(); err != nil {
|
||||||
|
w.Write(panicBytes)
|
||||||
|
fmt.Fprintf(w, "%v", err)
|
||||||
|
w.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleMethods attempts to call the Error and String methods on the underlying
|
||||||
|
// type the passed reflect.Value represents and outputes the result to Writer w.
|
||||||
|
//
|
||||||
|
// It handles panics in any called methods by catching and displaying the error
|
||||||
|
// as the formatted value.
|
||||||
|
func handleMethods(cs *ConfigState, w io.Writer, v reflect.Value) (handled bool) {
|
||||||
|
// We need an interface to check if the type implements the error or
|
||||||
|
// Stringer interface. However, the reflect package won't give us an
|
||||||
|
// interface on certain things like unexported struct fields in order
|
||||||
|
// to enforce visibility rules. We use unsafe, when it's available,
|
||||||
|
// to bypass these restrictions since this package does not mutate the
|
||||||
|
// values.
|
||||||
|
if !v.CanInterface() {
|
||||||
|
if UnsafeDisabled {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
v = unsafeReflectValue(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Choose whether or not to do error and Stringer interface lookups against
|
||||||
|
// the base type or a pointer to the base type depending on settings.
|
||||||
|
// Technically calling one of these methods with a pointer receiver can
|
||||||
|
// mutate the value, however, types which choose to satisify an error or
|
||||||
|
// Stringer interface with a pointer receiver should not be mutating their
|
||||||
|
// state inside these interface methods.
|
||||||
|
if !cs.DisablePointerMethods && !UnsafeDisabled && !v.CanAddr() {
|
||||||
|
v = unsafeReflectValue(v)
|
||||||
|
}
|
||||||
|
if v.CanAddr() {
|
||||||
|
v = v.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Is it an error or Stringer?
|
||||||
|
switch iface := v.Interface().(type) {
|
||||||
|
case error:
|
||||||
|
defer catchPanic(w, v)
|
||||||
|
if cs.ContinueOnMethod {
|
||||||
|
w.Write(openParenBytes)
|
||||||
|
w.Write([]byte(iface.Error()))
|
||||||
|
w.Write(closeParenBytes)
|
||||||
|
w.Write(spaceBytes)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write([]byte(iface.Error()))
|
||||||
|
return true
|
||||||
|
|
||||||
|
case fmt.Stringer:
|
||||||
|
defer catchPanic(w, v)
|
||||||
|
if cs.ContinueOnMethod {
|
||||||
|
w.Write(openParenBytes)
|
||||||
|
w.Write([]byte(iface.String()))
|
||||||
|
w.Write(closeParenBytes)
|
||||||
|
w.Write(spaceBytes)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
w.Write([]byte(iface.String()))
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// printBool outputs a boolean value as true or false to Writer w.
|
||||||
|
func printBool(w io.Writer, val bool) {
|
||||||
|
if val {
|
||||||
|
w.Write(trueBytes)
|
||||||
|
} else {
|
||||||
|
w.Write(falseBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// printInt outputs a signed integer value to Writer w.
|
||||||
|
func printInt(w io.Writer, val int64, base int) {
|
||||||
|
w.Write([]byte(strconv.FormatInt(val, base)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// printUint outputs an unsigned integer value to Writer w.
|
||||||
|
func printUint(w io.Writer, val uint64, base int) {
|
||||||
|
w.Write([]byte(strconv.FormatUint(val, base)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// printFloat outputs a floating point value using the specified precision,
|
||||||
|
// which is expected to be 32 or 64bit, to Writer w.
|
||||||
|
func printFloat(w io.Writer, val float64, precision int) {
|
||||||
|
w.Write([]byte(strconv.FormatFloat(val, 'g', -1, precision)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// printComplex outputs a complex value using the specified float precision
|
||||||
|
// for the real and imaginary parts to Writer w.
|
||||||
|
func printComplex(w io.Writer, c complex128, floatPrecision int) {
|
||||||
|
r := real(c)
|
||||||
|
w.Write(openParenBytes)
|
||||||
|
w.Write([]byte(strconv.FormatFloat(r, 'g', -1, floatPrecision)))
|
||||||
|
i := imag(c)
|
||||||
|
if i >= 0 {
|
||||||
|
w.Write(plusBytes)
|
||||||
|
}
|
||||||
|
w.Write([]byte(strconv.FormatFloat(i, 'g', -1, floatPrecision)))
|
||||||
|
w.Write(iBytes)
|
||||||
|
w.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// printHexPtr outputs a uintptr formatted as hexidecimal with a leading '0x'
|
||||||
|
// prefix to Writer w.
|
||||||
|
func printHexPtr(w io.Writer, p uintptr) {
|
||||||
|
// Null pointer.
|
||||||
|
num := uint64(p)
|
||||||
|
if num == 0 {
|
||||||
|
w.Write(nilAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Max uint64 is 16 bytes in hex + 2 bytes for '0x' prefix
|
||||||
|
buf := make([]byte, 18)
|
||||||
|
|
||||||
|
// It's simpler to construct the hex string right to left.
|
||||||
|
base := uint64(16)
|
||||||
|
i := len(buf) - 1
|
||||||
|
for num >= base {
|
||||||
|
buf[i] = hexDigits[num%base]
|
||||||
|
num /= base
|
||||||
|
i--
|
||||||
|
}
|
||||||
|
buf[i] = hexDigits[num]
|
||||||
|
|
||||||
|
// Add '0x' prefix.
|
||||||
|
i--
|
||||||
|
buf[i] = 'x'
|
||||||
|
i--
|
||||||
|
buf[i] = '0'
|
||||||
|
|
||||||
|
// Strip unused leading bytes.
|
||||||
|
buf = buf[i:]
|
||||||
|
w.Write(buf)
|
||||||
|
}
|
||||||
|
|
||||||
|
// valuesSorter implements sort.Interface to allow a slice of reflect.Value
|
||||||
|
// elements to be sorted.
|
||||||
|
type valuesSorter struct {
|
||||||
|
values []reflect.Value
|
||||||
|
strings []string // either nil or same len and values
|
||||||
|
cs *ConfigState
|
||||||
|
}
|
||||||
|
|
||||||
|
// newValuesSorter initializes a valuesSorter instance, which holds a set of
|
||||||
|
// surrogate keys on which the data should be sorted. It uses flags in
|
||||||
|
// ConfigState to decide if and how to populate those surrogate keys.
|
||||||
|
func newValuesSorter(values []reflect.Value, cs *ConfigState) sort.Interface {
|
||||||
|
vs := &valuesSorter{values: values, cs: cs}
|
||||||
|
if canSortSimply(vs.values[0].Kind()) {
|
||||||
|
return vs
|
||||||
|
}
|
||||||
|
if !cs.DisableMethods {
|
||||||
|
vs.strings = make([]string, len(values))
|
||||||
|
for i := range vs.values {
|
||||||
|
b := bytes.Buffer{}
|
||||||
|
if !handleMethods(cs, &b, vs.values[i]) {
|
||||||
|
vs.strings = nil
|
||||||
|
break
|
||||||
|
}
|
||||||
|
vs.strings[i] = b.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if vs.strings == nil && cs.SpewKeys {
|
||||||
|
vs.strings = make([]string, len(values))
|
||||||
|
for i := range vs.values {
|
||||||
|
vs.strings[i] = Sprintf("%#v", vs.values[i].Interface())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return vs
|
||||||
|
}
|
||||||
|
|
||||||
|
// canSortSimply tests whether a reflect.Kind is a primitive that can be sorted
|
||||||
|
// directly, or whether it should be considered for sorting by surrogate keys
|
||||||
|
// (if the ConfigState allows it).
|
||||||
|
func canSortSimply(kind reflect.Kind) bool {
|
||||||
|
// This switch parallels valueSortLess, except for the default case.
|
||||||
|
switch kind {
|
||||||
|
case reflect.Bool:
|
||||||
|
return true
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||||
|
return true
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||||
|
return true
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return true
|
||||||
|
case reflect.String:
|
||||||
|
return true
|
||||||
|
case reflect.Uintptr:
|
||||||
|
return true
|
||||||
|
case reflect.Array:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Len returns the number of values in the slice. It is part of the
|
||||||
|
// sort.Interface implementation.
|
||||||
|
func (s *valuesSorter) Len() int {
|
||||||
|
return len(s.values)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Swap swaps the values at the passed indices. It is part of the
|
||||||
|
// sort.Interface implementation.
|
||||||
|
func (s *valuesSorter) Swap(i, j int) {
|
||||||
|
s.values[i], s.values[j] = s.values[j], s.values[i]
|
||||||
|
if s.strings != nil {
|
||||||
|
s.strings[i], s.strings[j] = s.strings[j], s.strings[i]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// valueSortLess returns whether the first value should sort before the second
|
||||||
|
// value. It is used by valueSorter.Less as part of the sort.Interface
|
||||||
|
// implementation.
|
||||||
|
func valueSortLess(a, b reflect.Value) bool {
|
||||||
|
switch a.Kind() {
|
||||||
|
case reflect.Bool:
|
||||||
|
return !a.Bool() && b.Bool()
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||||
|
return a.Int() < b.Int()
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||||
|
return a.Uint() < b.Uint()
|
||||||
|
case reflect.Float32, reflect.Float64:
|
||||||
|
return a.Float() < b.Float()
|
||||||
|
case reflect.String:
|
||||||
|
return a.String() < b.String()
|
||||||
|
case reflect.Uintptr:
|
||||||
|
return a.Uint() < b.Uint()
|
||||||
|
case reflect.Array:
|
||||||
|
// Compare the contents of both arrays.
|
||||||
|
l := a.Len()
|
||||||
|
for i := 0; i < l; i++ {
|
||||||
|
av := a.Index(i)
|
||||||
|
bv := b.Index(i)
|
||||||
|
if av.Interface() == bv.Interface() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return valueSortLess(av, bv)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return a.String() < b.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Less returns whether the value at index i should sort before the
|
||||||
|
// value at index j. It is part of the sort.Interface implementation.
|
||||||
|
func (s *valuesSorter) Less(i, j int) bool {
|
||||||
|
if s.strings == nil {
|
||||||
|
return valueSortLess(s.values[i], s.values[j])
|
||||||
|
}
|
||||||
|
return s.strings[i] < s.strings[j]
|
||||||
|
}
|
||||||
|
|
||||||
|
// sortValues is a sort function that handles both native types and any type that
|
||||||
|
// can be converted to error or Stringer. Other inputs are sorted according to
|
||||||
|
// their Value.String() value to ensure display stability.
|
||||||
|
func sortValues(values []reflect.Value, cs *ConfigState) {
|
||||||
|
if len(values) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
sort.Sort(newValuesSorter(values, cs))
|
||||||
|
}
|
||||||
+306
@@ -0,0 +1,306 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConfigState houses the configuration options used by spew to format and
|
||||||
|
// display values. There is a global instance, Config, that is used to control
|
||||||
|
// all top-level Formatter and Dump functionality. Each ConfigState instance
|
||||||
|
// provides methods equivalent to the top-level functions.
|
||||||
|
//
|
||||||
|
// The zero value for ConfigState provides no indentation. You would typically
|
||||||
|
// want to set it to a space or a tab.
|
||||||
|
//
|
||||||
|
// Alternatively, you can use NewDefaultConfig to get a ConfigState instance
|
||||||
|
// with default settings. See the documentation of NewDefaultConfig for default
|
||||||
|
// values.
|
||||||
|
type ConfigState struct {
|
||||||
|
// Indent specifies the string to use for each indentation level. The
|
||||||
|
// global config instance that all top-level functions use set this to a
|
||||||
|
// single space by default. If you would like more indentation, you might
|
||||||
|
// set this to a tab with "\t" or perhaps two spaces with " ".
|
||||||
|
Indent string
|
||||||
|
|
||||||
|
// MaxDepth controls the maximum number of levels to descend into nested
|
||||||
|
// data structures. The default, 0, means there is no limit.
|
||||||
|
//
|
||||||
|
// NOTE: Circular data structures are properly detected, so it is not
|
||||||
|
// necessary to set this value unless you specifically want to limit deeply
|
||||||
|
// nested data structures.
|
||||||
|
MaxDepth int
|
||||||
|
|
||||||
|
// DisableMethods specifies whether or not error and Stringer interfaces are
|
||||||
|
// invoked for types that implement them.
|
||||||
|
DisableMethods bool
|
||||||
|
|
||||||
|
// DisablePointerMethods specifies whether or not to check for and invoke
|
||||||
|
// error and Stringer interfaces on types which only accept a pointer
|
||||||
|
// receiver when the current type is not a pointer.
|
||||||
|
//
|
||||||
|
// NOTE: This might be an unsafe action since calling one of these methods
|
||||||
|
// with a pointer receiver could technically mutate the value, however,
|
||||||
|
// in practice, types which choose to satisify an error or Stringer
|
||||||
|
// interface with a pointer receiver should not be mutating their state
|
||||||
|
// inside these interface methods. As a result, this option relies on
|
||||||
|
// access to the unsafe package, so it will not have any effect when
|
||||||
|
// running in environments without access to the unsafe package such as
|
||||||
|
// Google App Engine or with the "safe" build tag specified.
|
||||||
|
DisablePointerMethods bool
|
||||||
|
|
||||||
|
// DisablePointerAddresses specifies whether to disable the printing of
|
||||||
|
// pointer addresses. This is useful when diffing data structures in tests.
|
||||||
|
DisablePointerAddresses bool
|
||||||
|
|
||||||
|
// DisableCapacities specifies whether to disable the printing of capacities
|
||||||
|
// for arrays, slices, maps and channels. This is useful when diffing
|
||||||
|
// data structures in tests.
|
||||||
|
DisableCapacities bool
|
||||||
|
|
||||||
|
// ContinueOnMethod specifies whether or not recursion should continue once
|
||||||
|
// a custom error or Stringer interface is invoked. The default, false,
|
||||||
|
// means it will print the results of invoking the custom error or Stringer
|
||||||
|
// interface and return immediately instead of continuing to recurse into
|
||||||
|
// the internals of the data type.
|
||||||
|
//
|
||||||
|
// NOTE: This flag does not have any effect if method invocation is disabled
|
||||||
|
// via the DisableMethods or DisablePointerMethods options.
|
||||||
|
ContinueOnMethod bool
|
||||||
|
|
||||||
|
// SortKeys specifies map keys should be sorted before being printed. Use
|
||||||
|
// this to have a more deterministic, diffable output. Note that only
|
||||||
|
// native types (bool, int, uint, floats, uintptr and string) and types
|
||||||
|
// that support the error or Stringer interfaces (if methods are
|
||||||
|
// enabled) are supported, with other types sorted according to the
|
||||||
|
// reflect.Value.String() output which guarantees display stability.
|
||||||
|
SortKeys bool
|
||||||
|
|
||||||
|
// SpewKeys specifies that, as a last resort attempt, map keys should
|
||||||
|
// be spewed to strings and sorted by those strings. This is only
|
||||||
|
// considered if SortKeys is true.
|
||||||
|
SpewKeys bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config is the active configuration of the top-level functions.
|
||||||
|
// The configuration can be changed by modifying the contents of spew.Config.
|
||||||
|
var Config = ConfigState{Indent: " "}
|
||||||
|
|
||||||
|
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the formatted string as a value that satisfies error. See NewFormatter
|
||||||
|
// for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Errorf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Errorf(format string, a ...interface{}) (err error) {
|
||||||
|
return fmt.Errorf(format, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprint(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprint(w, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprintf(w, format, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprintf(w, format, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprintln(w, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprintln(w, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Print(c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Print(a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Print(c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Printf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Printf(format string, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Printf(format, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Println(c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Println(a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Println(c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprint(c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Sprint(a ...interface{}) string {
|
||||||
|
return fmt.Sprint(c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||||
|
// passed with a Formatter interface returned by c.NewFormatter. It returns
|
||||||
|
// the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprintf(format, c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Sprintf(format string, a ...interface{}) string {
|
||||||
|
return fmt.Sprintf(format, c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||||
|
// were passed with a Formatter interface returned by c.NewFormatter. It
|
||||||
|
// returns the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprintln(c.NewFormatter(a), c.NewFormatter(b))
|
||||||
|
func (c *ConfigState) Sprintln(a ...interface{}) string {
|
||||||
|
return fmt.Sprintln(c.convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||||
|
interface. As a result, it integrates cleanly with standard fmt package
|
||||||
|
printing functions. The formatter is useful for inline printing of smaller data
|
||||||
|
types similar to the standard %v format specifier.
|
||||||
|
|
||||||
|
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||||
|
addresses), %#v (adds types), and %#+v (adds types and pointer addresses) verb
|
||||||
|
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||||
|
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||||
|
the width and precision arguments (however they will still work on the format
|
||||||
|
specifiers not handled by the custom formatter).
|
||||||
|
|
||||||
|
Typically this function shouldn't be called directly. It is much easier to make
|
||||||
|
use of the custom formatter by calling one of the convenience functions such as
|
||||||
|
c.Printf, c.Println, or c.Printf.
|
||||||
|
*/
|
||||||
|
func (c *ConfigState) NewFormatter(v interface{}) fmt.Formatter {
|
||||||
|
return newFormatter(c, v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||||
|
// exactly the same as Dump.
|
||||||
|
func (c *ConfigState) Fdump(w io.Writer, a ...interface{}) {
|
||||||
|
fdump(c, w, a...)
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Dump displays the passed parameters to standard out with newlines, customizable
|
||||||
|
indentation, and additional debug information such as complete types and all
|
||||||
|
pointer addresses used to indirect to the final value. It provides the
|
||||||
|
following features over the built-in printing facilities provided by the fmt
|
||||||
|
package:
|
||||||
|
|
||||||
|
* Pointers are dereferenced and followed
|
||||||
|
* Circular data structures are detected and handled properly
|
||||||
|
* Custom Stringer/error interfaces are optionally invoked, including
|
||||||
|
on unexported types
|
||||||
|
* Custom types which only implement the Stringer/error interfaces via
|
||||||
|
a pointer receiver are optionally invoked when passing non-pointer
|
||||||
|
variables
|
||||||
|
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||||
|
includes offsets, byte values in hex, and ASCII output
|
||||||
|
|
||||||
|
The configuration options are controlled by modifying the public members
|
||||||
|
of c. See ConfigState for options documentation.
|
||||||
|
|
||||||
|
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||||
|
get the formatted result as a string.
|
||||||
|
*/
|
||||||
|
func (c *ConfigState) Dump(a ...interface{}) {
|
||||||
|
fdump(c, os.Stdout, a...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||||
|
// as Dump.
|
||||||
|
func (c *ConfigState) Sdump(a ...interface{}) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
fdump(c, &buf, a...)
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||||
|
// length with each argument converted to a spew Formatter interface using
|
||||||
|
// the ConfigState associated with s.
|
||||||
|
func (c *ConfigState) convertArgs(args []interface{}) (formatters []interface{}) {
|
||||||
|
formatters = make([]interface{}, len(args))
|
||||||
|
for index, arg := range args {
|
||||||
|
formatters[index] = newFormatter(c, arg)
|
||||||
|
}
|
||||||
|
return formatters
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDefaultConfig returns a ConfigState with the following default settings.
|
||||||
|
//
|
||||||
|
// Indent: " "
|
||||||
|
// MaxDepth: 0
|
||||||
|
// DisableMethods: false
|
||||||
|
// DisablePointerMethods: false
|
||||||
|
// ContinueOnMethod: false
|
||||||
|
// SortKeys: false
|
||||||
|
func NewDefaultConfig() *ConfigState {
|
||||||
|
return &ConfigState{Indent: " "}
|
||||||
|
}
|
||||||
+211
@@ -0,0 +1,211 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
/*
|
||||||
|
Package spew implements a deep pretty printer for Go data structures to aid in
|
||||||
|
debugging.
|
||||||
|
|
||||||
|
A quick overview of the additional features spew provides over the built-in
|
||||||
|
printing facilities for Go data types are as follows:
|
||||||
|
|
||||||
|
* Pointers are dereferenced and followed
|
||||||
|
* Circular data structures are detected and handled properly
|
||||||
|
* Custom Stringer/error interfaces are optionally invoked, including
|
||||||
|
on unexported types
|
||||||
|
* Custom types which only implement the Stringer/error interfaces via
|
||||||
|
a pointer receiver are optionally invoked when passing non-pointer
|
||||||
|
variables
|
||||||
|
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||||
|
includes offsets, byte values in hex, and ASCII output (only when using
|
||||||
|
Dump style)
|
||||||
|
|
||||||
|
There are two different approaches spew allows for dumping Go data structures:
|
||||||
|
|
||||||
|
* Dump style which prints with newlines, customizable indentation,
|
||||||
|
and additional debug information such as types and all pointer addresses
|
||||||
|
used to indirect to the final value
|
||||||
|
* A custom Formatter interface that integrates cleanly with the standard fmt
|
||||||
|
package and replaces %v, %+v, %#v, and %#+v to provide inline printing
|
||||||
|
similar to the default %v while providing the additional functionality
|
||||||
|
outlined above and passing unsupported format verbs such as %x and %q
|
||||||
|
along to fmt
|
||||||
|
|
||||||
|
Quick Start
|
||||||
|
|
||||||
|
This section demonstrates how to quickly get started with spew. See the
|
||||||
|
sections below for further details on formatting and configuration options.
|
||||||
|
|
||||||
|
To dump a variable with full newlines, indentation, type, and pointer
|
||||||
|
information use Dump, Fdump, or Sdump:
|
||||||
|
spew.Dump(myVar1, myVar2, ...)
|
||||||
|
spew.Fdump(someWriter, myVar1, myVar2, ...)
|
||||||
|
str := spew.Sdump(myVar1, myVar2, ...)
|
||||||
|
|
||||||
|
Alternatively, if you would prefer to use format strings with a compacted inline
|
||||||
|
printing style, use the convenience wrappers Printf, Fprintf, etc with
|
||||||
|
%v (most compact), %+v (adds pointer addresses), %#v (adds types), or
|
||||||
|
%#+v (adds types and pointer addresses):
|
||||||
|
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||||
|
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||||
|
spew.Fprintf(someWriter, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||||
|
spew.Fprintf(someWriter, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||||
|
|
||||||
|
Configuration Options
|
||||||
|
|
||||||
|
Configuration of spew is handled by fields in the ConfigState type. For
|
||||||
|
convenience, all of the top-level functions use a global state available
|
||||||
|
via the spew.Config global.
|
||||||
|
|
||||||
|
It is also possible to create a ConfigState instance that provides methods
|
||||||
|
equivalent to the top-level functions. This allows concurrent configuration
|
||||||
|
options. See the ConfigState documentation for more details.
|
||||||
|
|
||||||
|
The following configuration options are available:
|
||||||
|
* Indent
|
||||||
|
String to use for each indentation level for Dump functions.
|
||||||
|
It is a single space by default. A popular alternative is "\t".
|
||||||
|
|
||||||
|
* MaxDepth
|
||||||
|
Maximum number of levels to descend into nested data structures.
|
||||||
|
There is no limit by default.
|
||||||
|
|
||||||
|
* DisableMethods
|
||||||
|
Disables invocation of error and Stringer interface methods.
|
||||||
|
Method invocation is enabled by default.
|
||||||
|
|
||||||
|
* DisablePointerMethods
|
||||||
|
Disables invocation of error and Stringer interface methods on types
|
||||||
|
which only accept pointer receivers from non-pointer variables.
|
||||||
|
Pointer method invocation is enabled by default.
|
||||||
|
|
||||||
|
* DisablePointerAddresses
|
||||||
|
DisablePointerAddresses specifies whether to disable the printing of
|
||||||
|
pointer addresses. This is useful when diffing data structures in tests.
|
||||||
|
|
||||||
|
* DisableCapacities
|
||||||
|
DisableCapacities specifies whether to disable the printing of
|
||||||
|
capacities for arrays, slices, maps and channels. This is useful when
|
||||||
|
diffing data structures in tests.
|
||||||
|
|
||||||
|
* ContinueOnMethod
|
||||||
|
Enables recursion into types after invoking error and Stringer interface
|
||||||
|
methods. Recursion after method invocation is disabled by default.
|
||||||
|
|
||||||
|
* SortKeys
|
||||||
|
Specifies map keys should be sorted before being printed. Use
|
||||||
|
this to have a more deterministic, diffable output. Note that
|
||||||
|
only native types (bool, int, uint, floats, uintptr and string)
|
||||||
|
and types which implement error or Stringer interfaces are
|
||||||
|
supported with other types sorted according to the
|
||||||
|
reflect.Value.String() output which guarantees display
|
||||||
|
stability. Natural map order is used by default.
|
||||||
|
|
||||||
|
* SpewKeys
|
||||||
|
Specifies that, as a last resort attempt, map keys should be
|
||||||
|
spewed to strings and sorted by those strings. This is only
|
||||||
|
considered if SortKeys is true.
|
||||||
|
|
||||||
|
Dump Usage
|
||||||
|
|
||||||
|
Simply call spew.Dump with a list of variables you want to dump:
|
||||||
|
|
||||||
|
spew.Dump(myVar1, myVar2, ...)
|
||||||
|
|
||||||
|
You may also call spew.Fdump if you would prefer to output to an arbitrary
|
||||||
|
io.Writer. For example, to dump to standard error:
|
||||||
|
|
||||||
|
spew.Fdump(os.Stderr, myVar1, myVar2, ...)
|
||||||
|
|
||||||
|
A third option is to call spew.Sdump to get the formatted output as a string:
|
||||||
|
|
||||||
|
str := spew.Sdump(myVar1, myVar2, ...)
|
||||||
|
|
||||||
|
Sample Dump Output
|
||||||
|
|
||||||
|
See the Dump example for details on the setup of the types and variables being
|
||||||
|
shown here.
|
||||||
|
|
||||||
|
(main.Foo) {
|
||||||
|
unexportedField: (*main.Bar)(0xf84002e210)({
|
||||||
|
flag: (main.Flag) flagTwo,
|
||||||
|
data: (uintptr) <nil>
|
||||||
|
}),
|
||||||
|
ExportedField: (map[interface {}]interface {}) (len=1) {
|
||||||
|
(string) (len=3) "one": (bool) true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Byte (and uint8) arrays and slices are displayed uniquely like the hexdump -C
|
||||||
|
command as shown.
|
||||||
|
([]uint8) (len=32 cap=32) {
|
||||||
|
00000000 11 12 13 14 15 16 17 18 19 1a 1b 1c 1d 1e 1f 20 |............... |
|
||||||
|
00000010 21 22 23 24 25 26 27 28 29 2a 2b 2c 2d 2e 2f 30 |!"#$%&'()*+,-./0|
|
||||||
|
00000020 31 32 |12|
|
||||||
|
}
|
||||||
|
|
||||||
|
Custom Formatter
|
||||||
|
|
||||||
|
Spew provides a custom formatter that implements the fmt.Formatter interface
|
||||||
|
so that it integrates cleanly with standard fmt package printing functions. The
|
||||||
|
formatter is useful for inline printing of smaller data types similar to the
|
||||||
|
standard %v format specifier.
|
||||||
|
|
||||||
|
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||||
|
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||||
|
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||||
|
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||||
|
the width and precision arguments (however they will still work on the format
|
||||||
|
specifiers not handled by the custom formatter).
|
||||||
|
|
||||||
|
Custom Formatter Usage
|
||||||
|
|
||||||
|
The simplest way to make use of the spew custom formatter is to call one of the
|
||||||
|
convenience functions such as spew.Printf, spew.Println, or spew.Printf. The
|
||||||
|
functions have syntax you are most likely already familiar with:
|
||||||
|
|
||||||
|
spew.Printf("myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||||
|
spew.Printf("myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||||
|
spew.Println(myVar, myVar2)
|
||||||
|
spew.Fprintf(os.Stderr, "myVar1: %v -- myVar2: %+v", myVar1, myVar2)
|
||||||
|
spew.Fprintf(os.Stderr, "myVar3: %#v -- myVar4: %#+v", myVar3, myVar4)
|
||||||
|
|
||||||
|
See the Index for the full list convenience functions.
|
||||||
|
|
||||||
|
Sample Formatter Output
|
||||||
|
|
||||||
|
Double pointer to a uint8:
|
||||||
|
%v: <**>5
|
||||||
|
%+v: <**>(0xf8400420d0->0xf8400420c8)5
|
||||||
|
%#v: (**uint8)5
|
||||||
|
%#+v: (**uint8)(0xf8400420d0->0xf8400420c8)5
|
||||||
|
|
||||||
|
Pointer to circular struct with a uint8 field and a pointer to itself:
|
||||||
|
%v: <*>{1 <*><shown>}
|
||||||
|
%+v: <*>(0xf84003e260){ui8:1 c:<*>(0xf84003e260)<shown>}
|
||||||
|
%#v: (*main.circular){ui8:(uint8)1 c:(*main.circular)<shown>}
|
||||||
|
%#+v: (*main.circular)(0xf84003e260){ui8:(uint8)1 c:(*main.circular)(0xf84003e260)<shown>}
|
||||||
|
|
||||||
|
See the Printf example for details on the setup of variables being shown
|
||||||
|
here.
|
||||||
|
|
||||||
|
Errors
|
||||||
|
|
||||||
|
Since it is possible for custom Stringer/error interfaces to panic, spew
|
||||||
|
detects them and handles them internally by printing the panic information
|
||||||
|
inline with the output. Since spew is intended to provide deep pretty printing
|
||||||
|
capabilities on structures, it intentionally does not return any errors.
|
||||||
|
*/
|
||||||
|
package spew
|
||||||
+509
@@ -0,0 +1,509 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// uint8Type is a reflect.Type representing a uint8. It is used to
|
||||||
|
// convert cgo types to uint8 slices for hexdumping.
|
||||||
|
uint8Type = reflect.TypeOf(uint8(0))
|
||||||
|
|
||||||
|
// cCharRE is a regular expression that matches a cgo char.
|
||||||
|
// It is used to detect character arrays to hexdump them.
|
||||||
|
cCharRE = regexp.MustCompile("^.*\\._Ctype_char$")
|
||||||
|
|
||||||
|
// cUnsignedCharRE is a regular expression that matches a cgo unsigned
|
||||||
|
// char. It is used to detect unsigned character arrays to hexdump
|
||||||
|
// them.
|
||||||
|
cUnsignedCharRE = regexp.MustCompile("^.*\\._Ctype_unsignedchar$")
|
||||||
|
|
||||||
|
// cUint8tCharRE is a regular expression that matches a cgo uint8_t.
|
||||||
|
// It is used to detect uint8_t arrays to hexdump them.
|
||||||
|
cUint8tCharRE = regexp.MustCompile("^.*\\._Ctype_uint8_t$")
|
||||||
|
)
|
||||||
|
|
||||||
|
// dumpState contains information about the state of a dump operation.
|
||||||
|
type dumpState struct {
|
||||||
|
w io.Writer
|
||||||
|
depth int
|
||||||
|
pointers map[uintptr]int
|
||||||
|
ignoreNextType bool
|
||||||
|
ignoreNextIndent bool
|
||||||
|
cs *ConfigState
|
||||||
|
}
|
||||||
|
|
||||||
|
// indent performs indentation according to the depth level and cs.Indent
|
||||||
|
// option.
|
||||||
|
func (d *dumpState) indent() {
|
||||||
|
if d.ignoreNextIndent {
|
||||||
|
d.ignoreNextIndent = false
|
||||||
|
return
|
||||||
|
}
|
||||||
|
d.w.Write(bytes.Repeat([]byte(d.cs.Indent), d.depth))
|
||||||
|
}
|
||||||
|
|
||||||
|
// unpackValue returns values inside of non-nil interfaces when possible.
|
||||||
|
// This is useful for data types like structs, arrays, slices, and maps which
|
||||||
|
// can contain varying types packed inside an interface.
|
||||||
|
func (d *dumpState) unpackValue(v reflect.Value) reflect.Value {
|
||||||
|
if v.Kind() == reflect.Interface && !v.IsNil() {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// dumpPtr handles formatting of pointers by indirecting them as necessary.
|
||||||
|
func (d *dumpState) dumpPtr(v reflect.Value) {
|
||||||
|
// Remove pointers at or below the current depth from map used to detect
|
||||||
|
// circular refs.
|
||||||
|
for k, depth := range d.pointers {
|
||||||
|
if depth >= d.depth {
|
||||||
|
delete(d.pointers, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep list of all dereferenced pointers to show later.
|
||||||
|
pointerChain := make([]uintptr, 0)
|
||||||
|
|
||||||
|
// Figure out how many levels of indirection there are by dereferencing
|
||||||
|
// pointers and unpacking interfaces down the chain while detecting circular
|
||||||
|
// references.
|
||||||
|
nilFound := false
|
||||||
|
cycleFound := false
|
||||||
|
indirects := 0
|
||||||
|
ve := v
|
||||||
|
for ve.Kind() == reflect.Ptr {
|
||||||
|
if ve.IsNil() {
|
||||||
|
nilFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
indirects++
|
||||||
|
addr := ve.Pointer()
|
||||||
|
pointerChain = append(pointerChain, addr)
|
||||||
|
if pd, ok := d.pointers[addr]; ok && pd < d.depth {
|
||||||
|
cycleFound = true
|
||||||
|
indirects--
|
||||||
|
break
|
||||||
|
}
|
||||||
|
d.pointers[addr] = d.depth
|
||||||
|
|
||||||
|
ve = ve.Elem()
|
||||||
|
if ve.Kind() == reflect.Interface {
|
||||||
|
if ve.IsNil() {
|
||||||
|
nilFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
ve = ve.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display type information.
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
d.w.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||||
|
d.w.Write([]byte(ve.Type().String()))
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
|
||||||
|
// Display pointer information.
|
||||||
|
if !d.cs.DisablePointerAddresses && len(pointerChain) > 0 {
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
for i, addr := range pointerChain {
|
||||||
|
if i > 0 {
|
||||||
|
d.w.Write(pointerChainBytes)
|
||||||
|
}
|
||||||
|
printHexPtr(d.w, addr)
|
||||||
|
}
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display dereferenced value.
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
switch {
|
||||||
|
case nilFound == true:
|
||||||
|
d.w.Write(nilAngleBytes)
|
||||||
|
|
||||||
|
case cycleFound == true:
|
||||||
|
d.w.Write(circularBytes)
|
||||||
|
|
||||||
|
default:
|
||||||
|
d.ignoreNextType = true
|
||||||
|
d.dump(ve)
|
||||||
|
}
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// dumpSlice handles formatting of arrays and slices. Byte (uint8 under
|
||||||
|
// reflection) arrays and slices are dumped in hexdump -C fashion.
|
||||||
|
func (d *dumpState) dumpSlice(v reflect.Value) {
|
||||||
|
// Determine whether this type should be hex dumped or not. Also,
|
||||||
|
// for types which should be hexdumped, try to use the underlying data
|
||||||
|
// first, then fall back to trying to convert them to a uint8 slice.
|
||||||
|
var buf []uint8
|
||||||
|
doConvert := false
|
||||||
|
doHexDump := false
|
||||||
|
numEntries := v.Len()
|
||||||
|
if numEntries > 0 {
|
||||||
|
vt := v.Index(0).Type()
|
||||||
|
vts := vt.String()
|
||||||
|
switch {
|
||||||
|
// C types that need to be converted.
|
||||||
|
case cCharRE.MatchString(vts):
|
||||||
|
fallthrough
|
||||||
|
case cUnsignedCharRE.MatchString(vts):
|
||||||
|
fallthrough
|
||||||
|
case cUint8tCharRE.MatchString(vts):
|
||||||
|
doConvert = true
|
||||||
|
|
||||||
|
// Try to use existing uint8 slices and fall back to converting
|
||||||
|
// and copying if that fails.
|
||||||
|
case vt.Kind() == reflect.Uint8:
|
||||||
|
// We need an addressable interface to convert the type
|
||||||
|
// to a byte slice. However, the reflect package won't
|
||||||
|
// give us an interface on certain things like
|
||||||
|
// unexported struct fields in order to enforce
|
||||||
|
// visibility rules. We use unsafe, when available, to
|
||||||
|
// bypass these restrictions since this package does not
|
||||||
|
// mutate the values.
|
||||||
|
vs := v
|
||||||
|
if !vs.CanInterface() || !vs.CanAddr() {
|
||||||
|
vs = unsafeReflectValue(vs)
|
||||||
|
}
|
||||||
|
if !UnsafeDisabled {
|
||||||
|
vs = vs.Slice(0, numEntries)
|
||||||
|
|
||||||
|
// Use the existing uint8 slice if it can be
|
||||||
|
// type asserted.
|
||||||
|
iface := vs.Interface()
|
||||||
|
if slice, ok := iface.([]uint8); ok {
|
||||||
|
buf = slice
|
||||||
|
doHexDump = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The underlying data needs to be converted if it can't
|
||||||
|
// be type asserted to a uint8 slice.
|
||||||
|
doConvert = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copy and convert the underlying type if needed.
|
||||||
|
if doConvert && vt.ConvertibleTo(uint8Type) {
|
||||||
|
// Convert and copy each element into a uint8 byte
|
||||||
|
// slice.
|
||||||
|
buf = make([]uint8, numEntries)
|
||||||
|
for i := 0; i < numEntries; i++ {
|
||||||
|
vv := v.Index(i)
|
||||||
|
buf[i] = uint8(vv.Convert(uint8Type).Uint())
|
||||||
|
}
|
||||||
|
doHexDump = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hexdump the entire slice as needed.
|
||||||
|
if doHexDump {
|
||||||
|
indent := strings.Repeat(d.cs.Indent, d.depth)
|
||||||
|
str := indent + hex.Dump(buf)
|
||||||
|
str = strings.Replace(str, "\n", "\n"+indent, -1)
|
||||||
|
str = strings.TrimRight(str, d.cs.Indent)
|
||||||
|
d.w.Write([]byte(str))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Recursively call dump for each item.
|
||||||
|
for i := 0; i < numEntries; i++ {
|
||||||
|
d.dump(d.unpackValue(v.Index(i)))
|
||||||
|
if i < (numEntries - 1) {
|
||||||
|
d.w.Write(commaNewlineBytes)
|
||||||
|
} else {
|
||||||
|
d.w.Write(newlineBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// dump is the main workhorse for dumping a value. It uses the passed reflect
|
||||||
|
// value to figure out what kind of object we are dealing with and formats it
|
||||||
|
// appropriately. It is a recursive function, however circular data structures
|
||||||
|
// are detected and handled properly.
|
||||||
|
func (d *dumpState) dump(v reflect.Value) {
|
||||||
|
// Handle invalid reflect values immediately.
|
||||||
|
kind := v.Kind()
|
||||||
|
if kind == reflect.Invalid {
|
||||||
|
d.w.Write(invalidAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle pointers specially.
|
||||||
|
if kind == reflect.Ptr {
|
||||||
|
d.indent()
|
||||||
|
d.dumpPtr(v)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print type information unless already handled elsewhere.
|
||||||
|
if !d.ignoreNextType {
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
d.w.Write([]byte(v.Type().String()))
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
d.w.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
d.ignoreNextType = false
|
||||||
|
|
||||||
|
// Display length and capacity if the built-in len and cap functions
|
||||||
|
// work with the value's kind and the len/cap itself is non-zero.
|
||||||
|
valueLen, valueCap := 0, 0
|
||||||
|
switch v.Kind() {
|
||||||
|
case reflect.Array, reflect.Slice, reflect.Chan:
|
||||||
|
valueLen, valueCap = v.Len(), v.Cap()
|
||||||
|
case reflect.Map, reflect.String:
|
||||||
|
valueLen = v.Len()
|
||||||
|
}
|
||||||
|
if valueLen != 0 || !d.cs.DisableCapacities && valueCap != 0 {
|
||||||
|
d.w.Write(openParenBytes)
|
||||||
|
if valueLen != 0 {
|
||||||
|
d.w.Write(lenEqualsBytes)
|
||||||
|
printInt(d.w, int64(valueLen), 10)
|
||||||
|
}
|
||||||
|
if !d.cs.DisableCapacities && valueCap != 0 {
|
||||||
|
if valueLen != 0 {
|
||||||
|
d.w.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
d.w.Write(capEqualsBytes)
|
||||||
|
printInt(d.w, int64(valueCap), 10)
|
||||||
|
}
|
||||||
|
d.w.Write(closeParenBytes)
|
||||||
|
d.w.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call Stringer/error interfaces if they exist and the handle methods flag
|
||||||
|
// is enabled
|
||||||
|
if !d.cs.DisableMethods {
|
||||||
|
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||||
|
if handled := handleMethods(d.cs, d.w, v); handled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case reflect.Invalid:
|
||||||
|
// Do nothing. We should never get here since invalid has already
|
||||||
|
// been handled above.
|
||||||
|
|
||||||
|
case reflect.Bool:
|
||||||
|
printBool(d.w, v.Bool())
|
||||||
|
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||||
|
printInt(d.w, v.Int(), 10)
|
||||||
|
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||||
|
printUint(d.w, v.Uint(), 10)
|
||||||
|
|
||||||
|
case reflect.Float32:
|
||||||
|
printFloat(d.w, v.Float(), 32)
|
||||||
|
|
||||||
|
case reflect.Float64:
|
||||||
|
printFloat(d.w, v.Float(), 64)
|
||||||
|
|
||||||
|
case reflect.Complex64:
|
||||||
|
printComplex(d.w, v.Complex(), 32)
|
||||||
|
|
||||||
|
case reflect.Complex128:
|
||||||
|
printComplex(d.w, v.Complex(), 64)
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
if v.IsNil() {
|
||||||
|
d.w.Write(nilAngleBytes)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
|
||||||
|
case reflect.Array:
|
||||||
|
d.w.Write(openBraceNewlineBytes)
|
||||||
|
d.depth++
|
||||||
|
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(maxNewlineBytes)
|
||||||
|
} else {
|
||||||
|
d.dumpSlice(v)
|
||||||
|
}
|
||||||
|
d.depth--
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(closeBraceBytes)
|
||||||
|
|
||||||
|
case reflect.String:
|
||||||
|
d.w.Write([]byte(strconv.Quote(v.String())))
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
// The only time we should get here is for nil interfaces due to
|
||||||
|
// unpackValue calls.
|
||||||
|
if v.IsNil() {
|
||||||
|
d.w.Write(nilAngleBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Ptr:
|
||||||
|
// Do nothing. We should never get here since pointers have already
|
||||||
|
// been handled above.
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
// nil maps should be indicated as different than empty maps
|
||||||
|
if v.IsNil() {
|
||||||
|
d.w.Write(nilAngleBytes)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
d.w.Write(openBraceNewlineBytes)
|
||||||
|
d.depth++
|
||||||
|
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(maxNewlineBytes)
|
||||||
|
} else {
|
||||||
|
numEntries := v.Len()
|
||||||
|
keys := v.MapKeys()
|
||||||
|
if d.cs.SortKeys {
|
||||||
|
sortValues(keys, d.cs)
|
||||||
|
}
|
||||||
|
for i, key := range keys {
|
||||||
|
d.dump(d.unpackValue(key))
|
||||||
|
d.w.Write(colonSpaceBytes)
|
||||||
|
d.ignoreNextIndent = true
|
||||||
|
d.dump(d.unpackValue(v.MapIndex(key)))
|
||||||
|
if i < (numEntries - 1) {
|
||||||
|
d.w.Write(commaNewlineBytes)
|
||||||
|
} else {
|
||||||
|
d.w.Write(newlineBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d.depth--
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(closeBraceBytes)
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
d.w.Write(openBraceNewlineBytes)
|
||||||
|
d.depth++
|
||||||
|
if (d.cs.MaxDepth != 0) && (d.depth > d.cs.MaxDepth) {
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(maxNewlineBytes)
|
||||||
|
} else {
|
||||||
|
vt := v.Type()
|
||||||
|
numFields := v.NumField()
|
||||||
|
for i := 0; i < numFields; i++ {
|
||||||
|
d.indent()
|
||||||
|
vtf := vt.Field(i)
|
||||||
|
d.w.Write([]byte(vtf.Name))
|
||||||
|
d.w.Write(colonSpaceBytes)
|
||||||
|
d.ignoreNextIndent = true
|
||||||
|
d.dump(d.unpackValue(v.Field(i)))
|
||||||
|
if i < (numFields - 1) {
|
||||||
|
d.w.Write(commaNewlineBytes)
|
||||||
|
} else {
|
||||||
|
d.w.Write(newlineBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
d.depth--
|
||||||
|
d.indent()
|
||||||
|
d.w.Write(closeBraceBytes)
|
||||||
|
|
||||||
|
case reflect.Uintptr:
|
||||||
|
printHexPtr(d.w, uintptr(v.Uint()))
|
||||||
|
|
||||||
|
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||||
|
printHexPtr(d.w, v.Pointer())
|
||||||
|
|
||||||
|
// There were not any other types at the time this code was written, but
|
||||||
|
// fall back to letting the default fmt package handle it in case any new
|
||||||
|
// types are added.
|
||||||
|
default:
|
||||||
|
if v.CanInterface() {
|
||||||
|
fmt.Fprintf(d.w, "%v", v.Interface())
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(d.w, "%v", v.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// fdump is a helper function to consolidate the logic from the various public
|
||||||
|
// methods which take varying writers and config states.
|
||||||
|
func fdump(cs *ConfigState, w io.Writer, a ...interface{}) {
|
||||||
|
for _, arg := range a {
|
||||||
|
if arg == nil {
|
||||||
|
w.Write(interfaceBytes)
|
||||||
|
w.Write(spaceBytes)
|
||||||
|
w.Write(nilAngleBytes)
|
||||||
|
w.Write(newlineBytes)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
d := dumpState{w: w, cs: cs}
|
||||||
|
d.pointers = make(map[uintptr]int)
|
||||||
|
d.dump(reflect.ValueOf(arg))
|
||||||
|
d.w.Write(newlineBytes)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fdump formats and displays the passed arguments to io.Writer w. It formats
|
||||||
|
// exactly the same as Dump.
|
||||||
|
func Fdump(w io.Writer, a ...interface{}) {
|
||||||
|
fdump(&Config, w, a...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sdump returns a string with the passed arguments formatted exactly the same
|
||||||
|
// as Dump.
|
||||||
|
func Sdump(a ...interface{}) string {
|
||||||
|
var buf bytes.Buffer
|
||||||
|
fdump(&Config, &buf, a...)
|
||||||
|
return buf.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
Dump displays the passed parameters to standard out with newlines, customizable
|
||||||
|
indentation, and additional debug information such as complete types and all
|
||||||
|
pointer addresses used to indirect to the final value. It provides the
|
||||||
|
following features over the built-in printing facilities provided by the fmt
|
||||||
|
package:
|
||||||
|
|
||||||
|
* Pointers are dereferenced and followed
|
||||||
|
* Circular data structures are detected and handled properly
|
||||||
|
* Custom Stringer/error interfaces are optionally invoked, including
|
||||||
|
on unexported types
|
||||||
|
* Custom types which only implement the Stringer/error interfaces via
|
||||||
|
a pointer receiver are optionally invoked when passing non-pointer
|
||||||
|
variables
|
||||||
|
* Byte arrays and slices are dumped like the hexdump -C command which
|
||||||
|
includes offsets, byte values in hex, and ASCII output
|
||||||
|
|
||||||
|
The configuration options are controlled by an exported package global,
|
||||||
|
spew.Config. See ConfigState for options documentation.
|
||||||
|
|
||||||
|
See Fdump if you would prefer dumping to an arbitrary io.Writer or Sdump to
|
||||||
|
get the formatted result as a string.
|
||||||
|
*/
|
||||||
|
func Dump(a ...interface{}) {
|
||||||
|
fdump(&Config, os.Stdout, a...)
|
||||||
|
}
|
||||||
+419
@@ -0,0 +1,419 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// supportedFlags is a list of all the character flags supported by fmt package.
|
||||||
|
const supportedFlags = "0-+# "
|
||||||
|
|
||||||
|
// formatState implements the fmt.Formatter interface and contains information
|
||||||
|
// about the state of a formatting operation. The NewFormatter function can
|
||||||
|
// be used to get a new Formatter which can be used directly as arguments
|
||||||
|
// in standard fmt package printing calls.
|
||||||
|
type formatState struct {
|
||||||
|
value interface{}
|
||||||
|
fs fmt.State
|
||||||
|
depth int
|
||||||
|
pointers map[uintptr]int
|
||||||
|
ignoreNextType bool
|
||||||
|
cs *ConfigState
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildDefaultFormat recreates the original format string without precision
|
||||||
|
// and width information to pass in to fmt.Sprintf in the case of an
|
||||||
|
// unrecognized type. Unless new types are added to the language, this
|
||||||
|
// function won't ever be called.
|
||||||
|
func (f *formatState) buildDefaultFormat() (format string) {
|
||||||
|
buf := bytes.NewBuffer(percentBytes)
|
||||||
|
|
||||||
|
for _, flag := range supportedFlags {
|
||||||
|
if f.fs.Flag(int(flag)) {
|
||||||
|
buf.WriteRune(flag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteRune('v')
|
||||||
|
|
||||||
|
format = buf.String()
|
||||||
|
return format
|
||||||
|
}
|
||||||
|
|
||||||
|
// constructOrigFormat recreates the original format string including precision
|
||||||
|
// and width information to pass along to the standard fmt package. This allows
|
||||||
|
// automatic deferral of all format strings this package doesn't support.
|
||||||
|
func (f *formatState) constructOrigFormat(verb rune) (format string) {
|
||||||
|
buf := bytes.NewBuffer(percentBytes)
|
||||||
|
|
||||||
|
for _, flag := range supportedFlags {
|
||||||
|
if f.fs.Flag(int(flag)) {
|
||||||
|
buf.WriteRune(flag)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if width, ok := f.fs.Width(); ok {
|
||||||
|
buf.WriteString(strconv.Itoa(width))
|
||||||
|
}
|
||||||
|
|
||||||
|
if precision, ok := f.fs.Precision(); ok {
|
||||||
|
buf.Write(precisionBytes)
|
||||||
|
buf.WriteString(strconv.Itoa(precision))
|
||||||
|
}
|
||||||
|
|
||||||
|
buf.WriteRune(verb)
|
||||||
|
|
||||||
|
format = buf.String()
|
||||||
|
return format
|
||||||
|
}
|
||||||
|
|
||||||
|
// unpackValue returns values inside of non-nil interfaces when possible and
|
||||||
|
// ensures that types for values which have been unpacked from an interface
|
||||||
|
// are displayed when the show types flag is also set.
|
||||||
|
// This is useful for data types like structs, arrays, slices, and maps which
|
||||||
|
// can contain varying types packed inside an interface.
|
||||||
|
func (f *formatState) unpackValue(v reflect.Value) reflect.Value {
|
||||||
|
if v.Kind() == reflect.Interface {
|
||||||
|
f.ignoreNextType = false
|
||||||
|
if !v.IsNil() {
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
|
||||||
|
// formatPtr handles formatting of pointers by indirecting them as necessary.
|
||||||
|
func (f *formatState) formatPtr(v reflect.Value) {
|
||||||
|
// Display nil if top level pointer is nil.
|
||||||
|
showTypes := f.fs.Flag('#')
|
||||||
|
if v.IsNil() && (!showTypes || f.ignoreNextType) {
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remove pointers at or below the current depth from map used to detect
|
||||||
|
// circular refs.
|
||||||
|
for k, depth := range f.pointers {
|
||||||
|
if depth >= f.depth {
|
||||||
|
delete(f.pointers, k)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep list of all dereferenced pointers to possibly show later.
|
||||||
|
pointerChain := make([]uintptr, 0)
|
||||||
|
|
||||||
|
// Figure out how many levels of indirection there are by derferencing
|
||||||
|
// pointers and unpacking interfaces down the chain while detecting circular
|
||||||
|
// references.
|
||||||
|
nilFound := false
|
||||||
|
cycleFound := false
|
||||||
|
indirects := 0
|
||||||
|
ve := v
|
||||||
|
for ve.Kind() == reflect.Ptr {
|
||||||
|
if ve.IsNil() {
|
||||||
|
nilFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
indirects++
|
||||||
|
addr := ve.Pointer()
|
||||||
|
pointerChain = append(pointerChain, addr)
|
||||||
|
if pd, ok := f.pointers[addr]; ok && pd < f.depth {
|
||||||
|
cycleFound = true
|
||||||
|
indirects--
|
||||||
|
break
|
||||||
|
}
|
||||||
|
f.pointers[addr] = f.depth
|
||||||
|
|
||||||
|
ve = ve.Elem()
|
||||||
|
if ve.Kind() == reflect.Interface {
|
||||||
|
if ve.IsNil() {
|
||||||
|
nilFound = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
ve = ve.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display type or indirection level depending on flags.
|
||||||
|
if showTypes && !f.ignoreNextType {
|
||||||
|
f.fs.Write(openParenBytes)
|
||||||
|
f.fs.Write(bytes.Repeat(asteriskBytes, indirects))
|
||||||
|
f.fs.Write([]byte(ve.Type().String()))
|
||||||
|
f.fs.Write(closeParenBytes)
|
||||||
|
} else {
|
||||||
|
if nilFound || cycleFound {
|
||||||
|
indirects += strings.Count(ve.Type().String(), "*")
|
||||||
|
}
|
||||||
|
f.fs.Write(openAngleBytes)
|
||||||
|
f.fs.Write([]byte(strings.Repeat("*", indirects)))
|
||||||
|
f.fs.Write(closeAngleBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display pointer information depending on flags.
|
||||||
|
if f.fs.Flag('+') && (len(pointerChain) > 0) {
|
||||||
|
f.fs.Write(openParenBytes)
|
||||||
|
for i, addr := range pointerChain {
|
||||||
|
if i > 0 {
|
||||||
|
f.fs.Write(pointerChainBytes)
|
||||||
|
}
|
||||||
|
printHexPtr(f.fs, addr)
|
||||||
|
}
|
||||||
|
f.fs.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Display dereferenced value.
|
||||||
|
switch {
|
||||||
|
case nilFound == true:
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
|
||||||
|
case cycleFound == true:
|
||||||
|
f.fs.Write(circularShortBytes)
|
||||||
|
|
||||||
|
default:
|
||||||
|
f.ignoreNextType = true
|
||||||
|
f.format(ve)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// format is the main workhorse for providing the Formatter interface. It
|
||||||
|
// uses the passed reflect value to figure out what kind of object we are
|
||||||
|
// dealing with and formats it appropriately. It is a recursive function,
|
||||||
|
// however circular data structures are detected and handled properly.
|
||||||
|
func (f *formatState) format(v reflect.Value) {
|
||||||
|
// Handle invalid reflect values immediately.
|
||||||
|
kind := v.Kind()
|
||||||
|
if kind == reflect.Invalid {
|
||||||
|
f.fs.Write(invalidAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle pointers specially.
|
||||||
|
if kind == reflect.Ptr {
|
||||||
|
f.formatPtr(v)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print type information unless already handled elsewhere.
|
||||||
|
if !f.ignoreNextType && f.fs.Flag('#') {
|
||||||
|
f.fs.Write(openParenBytes)
|
||||||
|
f.fs.Write([]byte(v.Type().String()))
|
||||||
|
f.fs.Write(closeParenBytes)
|
||||||
|
}
|
||||||
|
f.ignoreNextType = false
|
||||||
|
|
||||||
|
// Call Stringer/error interfaces if they exist and the handle methods
|
||||||
|
// flag is enabled.
|
||||||
|
if !f.cs.DisableMethods {
|
||||||
|
if (kind != reflect.Invalid) && (kind != reflect.Interface) {
|
||||||
|
if handled := handleMethods(f.cs, f.fs, v); handled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch kind {
|
||||||
|
case reflect.Invalid:
|
||||||
|
// Do nothing. We should never get here since invalid has already
|
||||||
|
// been handled above.
|
||||||
|
|
||||||
|
case reflect.Bool:
|
||||||
|
printBool(f.fs, v.Bool())
|
||||||
|
|
||||||
|
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||||
|
printInt(f.fs, v.Int(), 10)
|
||||||
|
|
||||||
|
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||||
|
printUint(f.fs, v.Uint(), 10)
|
||||||
|
|
||||||
|
case reflect.Float32:
|
||||||
|
printFloat(f.fs, v.Float(), 32)
|
||||||
|
|
||||||
|
case reflect.Float64:
|
||||||
|
printFloat(f.fs, v.Float(), 64)
|
||||||
|
|
||||||
|
case reflect.Complex64:
|
||||||
|
printComplex(f.fs, v.Complex(), 32)
|
||||||
|
|
||||||
|
case reflect.Complex128:
|
||||||
|
printComplex(f.fs, v.Complex(), 64)
|
||||||
|
|
||||||
|
case reflect.Slice:
|
||||||
|
if v.IsNil() {
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
fallthrough
|
||||||
|
|
||||||
|
case reflect.Array:
|
||||||
|
f.fs.Write(openBracketBytes)
|
||||||
|
f.depth++
|
||||||
|
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||||
|
f.fs.Write(maxShortBytes)
|
||||||
|
} else {
|
||||||
|
numEntries := v.Len()
|
||||||
|
for i := 0; i < numEntries; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
f.fs.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
f.ignoreNextType = true
|
||||||
|
f.format(f.unpackValue(v.Index(i)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.depth--
|
||||||
|
f.fs.Write(closeBracketBytes)
|
||||||
|
|
||||||
|
case reflect.String:
|
||||||
|
f.fs.Write([]byte(v.String()))
|
||||||
|
|
||||||
|
case reflect.Interface:
|
||||||
|
// The only time we should get here is for nil interfaces due to
|
||||||
|
// unpackValue calls.
|
||||||
|
if v.IsNil() {
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
case reflect.Ptr:
|
||||||
|
// Do nothing. We should never get here since pointers have already
|
||||||
|
// been handled above.
|
||||||
|
|
||||||
|
case reflect.Map:
|
||||||
|
// nil maps should be indicated as different than empty maps
|
||||||
|
if v.IsNil() {
|
||||||
|
f.fs.Write(nilAngleBytes)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
f.fs.Write(openMapBytes)
|
||||||
|
f.depth++
|
||||||
|
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||||
|
f.fs.Write(maxShortBytes)
|
||||||
|
} else {
|
||||||
|
keys := v.MapKeys()
|
||||||
|
if f.cs.SortKeys {
|
||||||
|
sortValues(keys, f.cs)
|
||||||
|
}
|
||||||
|
for i, key := range keys {
|
||||||
|
if i > 0 {
|
||||||
|
f.fs.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
f.ignoreNextType = true
|
||||||
|
f.format(f.unpackValue(key))
|
||||||
|
f.fs.Write(colonBytes)
|
||||||
|
f.ignoreNextType = true
|
||||||
|
f.format(f.unpackValue(v.MapIndex(key)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.depth--
|
||||||
|
f.fs.Write(closeMapBytes)
|
||||||
|
|
||||||
|
case reflect.Struct:
|
||||||
|
numFields := v.NumField()
|
||||||
|
f.fs.Write(openBraceBytes)
|
||||||
|
f.depth++
|
||||||
|
if (f.cs.MaxDepth != 0) && (f.depth > f.cs.MaxDepth) {
|
||||||
|
f.fs.Write(maxShortBytes)
|
||||||
|
} else {
|
||||||
|
vt := v.Type()
|
||||||
|
for i := 0; i < numFields; i++ {
|
||||||
|
if i > 0 {
|
||||||
|
f.fs.Write(spaceBytes)
|
||||||
|
}
|
||||||
|
vtf := vt.Field(i)
|
||||||
|
if f.fs.Flag('+') || f.fs.Flag('#') {
|
||||||
|
f.fs.Write([]byte(vtf.Name))
|
||||||
|
f.fs.Write(colonBytes)
|
||||||
|
}
|
||||||
|
f.format(f.unpackValue(v.Field(i)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.depth--
|
||||||
|
f.fs.Write(closeBraceBytes)
|
||||||
|
|
||||||
|
case reflect.Uintptr:
|
||||||
|
printHexPtr(f.fs, uintptr(v.Uint()))
|
||||||
|
|
||||||
|
case reflect.UnsafePointer, reflect.Chan, reflect.Func:
|
||||||
|
printHexPtr(f.fs, v.Pointer())
|
||||||
|
|
||||||
|
// There were not any other types at the time this code was written, but
|
||||||
|
// fall back to letting the default fmt package handle it if any get added.
|
||||||
|
default:
|
||||||
|
format := f.buildDefaultFormat()
|
||||||
|
if v.CanInterface() {
|
||||||
|
fmt.Fprintf(f.fs, format, v.Interface())
|
||||||
|
} else {
|
||||||
|
fmt.Fprintf(f.fs, format, v.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format satisfies the fmt.Formatter interface. See NewFormatter for usage
|
||||||
|
// details.
|
||||||
|
func (f *formatState) Format(fs fmt.State, verb rune) {
|
||||||
|
f.fs = fs
|
||||||
|
|
||||||
|
// Use standard formatting for verbs that are not v.
|
||||||
|
if verb != 'v' {
|
||||||
|
format := f.constructOrigFormat(verb)
|
||||||
|
fmt.Fprintf(fs, format, f.value)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.value == nil {
|
||||||
|
if fs.Flag('#') {
|
||||||
|
fs.Write(interfaceBytes)
|
||||||
|
}
|
||||||
|
fs.Write(nilAngleBytes)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.format(reflect.ValueOf(f.value))
|
||||||
|
}
|
||||||
|
|
||||||
|
// newFormatter is a helper function to consolidate the logic from the various
|
||||||
|
// public methods which take varying config states.
|
||||||
|
func newFormatter(cs *ConfigState, v interface{}) fmt.Formatter {
|
||||||
|
fs := &formatState{value: v, cs: cs}
|
||||||
|
fs.pointers = make(map[uintptr]int)
|
||||||
|
return fs
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
NewFormatter returns a custom formatter that satisfies the fmt.Formatter
|
||||||
|
interface. As a result, it integrates cleanly with standard fmt package
|
||||||
|
printing functions. The formatter is useful for inline printing of smaller data
|
||||||
|
types similar to the standard %v format specifier.
|
||||||
|
|
||||||
|
The custom formatter only responds to the %v (most compact), %+v (adds pointer
|
||||||
|
addresses), %#v (adds types), or %#+v (adds types and pointer addresses) verb
|
||||||
|
combinations. Any other verbs such as %x and %q will be sent to the the
|
||||||
|
standard fmt package for formatting. In addition, the custom formatter ignores
|
||||||
|
the width and precision arguments (however they will still work on the format
|
||||||
|
specifiers not handled by the custom formatter).
|
||||||
|
|
||||||
|
Typically this function shouldn't be called directly. It is much easier to make
|
||||||
|
use of the custom formatter by calling one of the convenience functions such as
|
||||||
|
Printf, Println, or Fprintf.
|
||||||
|
*/
|
||||||
|
func NewFormatter(v interface{}) fmt.Formatter {
|
||||||
|
return newFormatter(&Config, v)
|
||||||
|
}
|
||||||
+148
@@ -0,0 +1,148 @@
|
|||||||
|
/*
|
||||||
|
* Copyright (c) 2013-2016 Dave Collins <dave@davec.name>
|
||||||
|
*
|
||||||
|
* Permission to use, copy, modify, and distribute this software for any
|
||||||
|
* purpose with or without fee is hereby granted, provided that the above
|
||||||
|
* copyright notice and this permission notice appear in all copies.
|
||||||
|
*
|
||||||
|
* THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
||||||
|
* WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
||||||
|
* MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
||||||
|
* ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
||||||
|
* WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
||||||
|
* ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
||||||
|
* OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package spew
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Errorf is a wrapper for fmt.Errorf that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the formatted string as a value that satisfies error. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Errorf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Errorf(format string, a ...interface{}) (err error) {
|
||||||
|
return fmt.Errorf(format, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprint is a wrapper for fmt.Fprint that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprint(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Fprint(w io.Writer, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprint(w, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprintf is a wrapper for fmt.Fprintf that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprintf(w, format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Fprintf(w io.Writer, format string, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprintf(w, format, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fprintln is a wrapper for fmt.Fprintln that treats each argument as if it
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Fprintln(w, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Fprintln(w io.Writer, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Fprintln(w, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print is a wrapper for fmt.Print that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Print(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Print(a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Print(convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Printf is a wrapper for fmt.Printf that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Printf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Printf(format string, a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Printf(format, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Println is a wrapper for fmt.Println that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the number of bytes written and any write error encountered. See
|
||||||
|
// NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Println(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Println(a ...interface{}) (n int, err error) {
|
||||||
|
return fmt.Println(convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprint is a wrapper for fmt.Sprint that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprint(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Sprint(a ...interface{}) string {
|
||||||
|
return fmt.Sprint(convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprintf is a wrapper for fmt.Sprintf that treats each argument as if it were
|
||||||
|
// passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprintf(format, spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Sprintf(format string, a ...interface{}) string {
|
||||||
|
return fmt.Sprintf(format, convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sprintln is a wrapper for fmt.Sprintln that treats each argument as if it
|
||||||
|
// were passed with a default Formatter interface returned by NewFormatter. It
|
||||||
|
// returns the resulting string. See NewFormatter for formatting details.
|
||||||
|
//
|
||||||
|
// This function is shorthand for the following syntax:
|
||||||
|
//
|
||||||
|
// fmt.Sprintln(spew.NewFormatter(a), spew.NewFormatter(b))
|
||||||
|
func Sprintln(a ...interface{}) string {
|
||||||
|
return fmt.Sprintln(convertArgs(a)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertArgs accepts a slice of arguments and returns a slice of the same
|
||||||
|
// length with each argument converted to a default spew Formatter interface.
|
||||||
|
func convertArgs(args []interface{}) (formatters []interface{}) {
|
||||||
|
formatters = make([]interface{}, len(args))
|
||||||
|
for index, arg := range args {
|
||||||
|
formatters[index] = NewFormatter(arg)
|
||||||
|
}
|
||||||
|
return formatters
|
||||||
|
}
|
||||||
+87
@@ -0,0 +1,87 @@
|
|||||||
|
Eclipse Public License - v 1.0
|
||||||
|
|
||||||
|
THE ACCOMPANYING PROGRAM IS PROVIDED UNDER THE TERMS OF THIS ECLIPSE PUBLIC LICENSE ("AGREEMENT"). ANY USE, REPRODUCTION OR DISTRIBUTION OF THE PROGRAM CONSTITUTES RECIPIENT'S ACCEPTANCE OF THIS AGREEMENT.
|
||||||
|
|
||||||
|
1. DEFINITIONS
|
||||||
|
|
||||||
|
"Contribution" means:
|
||||||
|
|
||||||
|
a) in the case of the initial Contributor, the initial code and documentation distributed under this Agreement, and
|
||||||
|
|
||||||
|
b) in the case of each subsequent Contributor:
|
||||||
|
|
||||||
|
i) changes to the Program, and
|
||||||
|
|
||||||
|
ii) additions to the Program;
|
||||||
|
|
||||||
|
where such changes and/or additions to the Program originate from and are distributed by that particular Contributor. A Contribution 'originates' from a Contributor if it was added to the Program by such Contributor itself or anyone acting on such Contributor's behalf. Contributions do not include additions to the Program which: (i) are separate modules of software distributed in conjunction with the Program under their own license agreement, and (ii) are not derivative works of the Program.
|
||||||
|
|
||||||
|
"Contributor" means any person or entity that distributes the Program.
|
||||||
|
|
||||||
|
"Licensed Patents" mean patent claims licensable by a Contributor which are necessarily infringed by the use or sale of its Contribution alone or when combined with the Program.
|
||||||
|
|
||||||
|
"Program" means the Contributions distributed in accordance with this Agreement.
|
||||||
|
|
||||||
|
"Recipient" means anyone who receives the Program under this Agreement, including all Contributors.
|
||||||
|
|
||||||
|
2. GRANT OF RIGHTS
|
||||||
|
|
||||||
|
a) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free copyright license to reproduce, prepare derivative works of, publicly display, publicly perform, distribute and sublicense the Contribution of such Contributor, if any, and such derivative works, in source code and object code form.
|
||||||
|
|
||||||
|
b) Subject to the terms of this Agreement, each Contributor hereby grants Recipient a non-exclusive, worldwide, royalty-free patent license under Licensed Patents to make, use, sell, offer to sell, import and otherwise transfer the Contribution of such Contributor, if any, in source code and object code form. This patent license shall apply to the combination of the Contribution and the Program if, at the time the Contribution is added by the Contributor, such addition of the Contribution causes such combination to be covered by the Licensed Patents. The patent license shall not apply to any other combinations which include the Contribution. No hardware per se is licensed hereunder.
|
||||||
|
|
||||||
|
c) Recipient understands that although each Contributor grants the licenses to its Contributions set forth herein, no assurances are provided by any Contributor that the Program does not infringe the patent or other intellectual property rights of any other entity. Each Contributor disclaims any liability to Recipient for claims brought by any other entity based on infringement of intellectual property rights or otherwise. As a condition to exercising the rights and licenses granted hereunder, each Recipient hereby assumes sole responsibility to secure any other intellectual property rights needed, if any. For example, if a third party patent license is required to allow Recipient to distribute the Program, it is Recipient's responsibility to acquire that license before distributing the Program.
|
||||||
|
|
||||||
|
d) Each Contributor represents that to its knowledge it has sufficient copyright rights in its Contribution, if any, to grant the copyright license set forth in this Agreement.
|
||||||
|
|
||||||
|
3. REQUIREMENTS
|
||||||
|
|
||||||
|
A Contributor may choose to distribute the Program in object code form under its own license agreement, provided that:
|
||||||
|
|
||||||
|
a) it complies with the terms and conditions of this Agreement; and
|
||||||
|
|
||||||
|
b) its license agreement:
|
||||||
|
|
||||||
|
i) effectively disclaims on behalf of all Contributors all warranties and conditions, express and implied, including warranties or conditions of title and non-infringement, and implied warranties or conditions of merchantability and fitness for a particular purpose;
|
||||||
|
|
||||||
|
ii) effectively excludes on behalf of all Contributors all liability for damages, including direct, indirect, special, incidental and consequential damages, such as lost profits;
|
||||||
|
|
||||||
|
iii) states that any provisions which differ from this Agreement are offered by that Contributor alone and not by any other party; and
|
||||||
|
|
||||||
|
iv) states that source code for the Program is available from such Contributor, and informs licensees how to obtain it in a reasonable manner on or through a medium customarily used for software exchange.
|
||||||
|
|
||||||
|
When the Program is made available in source code form:
|
||||||
|
|
||||||
|
a) it must be made available under this Agreement; and
|
||||||
|
|
||||||
|
b) a copy of this Agreement must be included with each copy of the Program.
|
||||||
|
|
||||||
|
Contributors may not remove or alter any copyright notices contained within the Program.
|
||||||
|
|
||||||
|
Each Contributor must identify itself as the originator of its Contribution, if any, in a manner that reasonably allows subsequent Recipients to identify the originator of the Contribution.
|
||||||
|
|
||||||
|
4. COMMERCIAL DISTRIBUTION
|
||||||
|
|
||||||
|
Commercial distributors of software may accept certain responsibilities with respect to end users, business partners and the like. While this license is intended to facilitate the commercial use of the Program, the Contributor who includes the Program in a commercial product offering should do so in a manner which does not create potential liability for other Contributors. Therefore, if a Contributor includes the Program in a commercial product offering, such Contributor ("Commercial Contributor") hereby agrees to defend and indemnify every other Contributor ("Indemnified Contributor") against any losses, damages and costs (collectively "Losses") arising from claims, lawsuits and other legal actions brought by a third party against the Indemnified Contributor to the extent caused by the acts or omissions of such Commercial Contributor in connection with its distribution of the Program in a commercial product offering. The obligations in this section do not apply to any claims or Losses relating to any actual or alleged intellectual property infringement. In order to qualify, an Indemnified Contributor must: a) promptly notify the Commercial Contributor in writing of such claim, and b) allow the Commercial Contributor to control, and cooperate with the Commercial Contributor in, the defense and any related settlement negotiations. The Indemnified Contributor may participate in any such claim at its own expense.
|
||||||
|
|
||||||
|
For example, a Contributor might include the Program in a commercial product offering, Product X. That Contributor is then a Commercial Contributor. If that Commercial Contributor then makes performance claims, or offers warranties related to Product X, those performance claims and warranties are such Commercial Contributor's responsibility alone. Under this section, the Commercial Contributor would have to defend claims against the other Contributors related to those performance claims and warranties, and if a court requires any other Contributor to pay any damages as a result, the Commercial Contributor must pay those damages.
|
||||||
|
|
||||||
|
5. NO WARRANTY
|
||||||
|
|
||||||
|
EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, THE PROGRAM IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OR CONDITIONS OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE. Each Recipient is solely responsible for determining the appropriateness of using and distributing the Program and assumes all risks associated with its exercise of rights under this Agreement , including but not limited to the risks and costs of program errors, compliance with applicable laws, damage to or loss of data, programs or equipment, and unavailability or interruption of operations.
|
||||||
|
|
||||||
|
6. DISCLAIMER OF LIABILITY
|
||||||
|
|
||||||
|
EXCEPT AS EXPRESSLY SET FORTH IN THIS AGREEMENT, NEITHER RECIPIENT NOR ANY CONTRIBUTORS SHALL HAVE ANY LIABILITY FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING WITHOUT LIMITATION LOST PROFITS), HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OR DISTRIBUTION OF THE PROGRAM OR THE EXERCISE OF ANY RIGHTS GRANTED HEREUNDER, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
|
||||||
|
|
||||||
|
7. GENERAL
|
||||||
|
|
||||||
|
If any provision of this Agreement is invalid or unenforceable under applicable law, it shall not affect the validity or enforceability of the remainder of the terms of this Agreement, and without further action by the parties hereto, such provision shall be reformed to the minimum extent necessary to make such provision valid and enforceable.
|
||||||
|
|
||||||
|
If Recipient institutes patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Program itself (excluding combinations of the Program with other software or hardware) infringes such Recipient's patent(s), then such Recipient's rights granted under Section 2(b) shall terminate as of the date such litigation is filed.
|
||||||
|
|
||||||
|
All Recipient's rights under this Agreement shall terminate if it fails to comply with any of the material terms or conditions of this Agreement and does not cure such failure in a reasonable period of time after becoming aware of such noncompliance. If all Recipient's rights under this Agreement terminate, Recipient agrees to cease use and distribution of the Program as soon as reasonably practicable. However, Recipient's obligations under this Agreement and any licenses granted by Recipient relating to the Program shall continue and survive.
|
||||||
|
|
||||||
|
Everyone is permitted to copy and distribute copies of this Agreement, but in order to avoid inconsistency the Agreement is copyrighted and may only be modified in the following manner. The Agreement Steward reserves the right to publish new versions (including revisions) of this Agreement from time to time. No one other than the Agreement Steward has the right to modify this Agreement. The Eclipse Foundation is the initial Agreement Steward. The Eclipse Foundation may assign the responsibility to serve as the Agreement Steward to a suitable separate entity. Each new version of the Agreement will be given a distinguishing version number. The Program (including Contributions) may always be distributed subject to the version of the Agreement under which it was received. In addition, after a new version of the Agreement is published, Contributor may elect to distribute the Program (including its Contributions) under the new version. Except as expressly stated in Sections 2(a) and 2(b) above, Recipient receives no rights or licenses to the intellectual property of any Contributor under this Agreement, whether expressly, by implication, estoppel or otherwise. All rights in the Program not expressly granted under this Agreement are reserved.
|
||||||
|
|
||||||
|
This Agreement is governed by the laws of the State of New York and the intellectual property laws of the United States of America. No party to this Agreement will bring a legal action under this Agreement more than one year after the cause of action arose. Each party waives its rights to a jury trial in any resulting litigation.
|
||||||
+55
@@ -0,0 +1,55 @@
|
|||||||
|
package packets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
//ConnackPacket is an internal representation of the fields of the
|
||||||
|
//Connack MQTT packet
|
||||||
|
type ConnackPacket struct {
|
||||||
|
FixedHeader
|
||||||
|
SessionPresent bool
|
||||||
|
ReturnCode byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ca *ConnackPacket) String() string {
|
||||||
|
str := fmt.Sprintf("%s", ca.FixedHeader)
|
||||||
|
str += " "
|
||||||
|
str += fmt.Sprintf("sessionpresent: %t returncode: %d", ca.SessionPresent, ca.ReturnCode)
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ca *ConnackPacket) Write(w io.Writer) error {
|
||||||
|
var body bytes.Buffer
|
||||||
|
var err error
|
||||||
|
|
||||||
|
body.WriteByte(boolToByte(ca.SessionPresent))
|
||||||
|
body.WriteByte(ca.ReturnCode)
|
||||||
|
ca.FixedHeader.RemainingLength = 2
|
||||||
|
packet := ca.FixedHeader.pack()
|
||||||
|
packet.Write(body.Bytes())
|
||||||
|
_, err = packet.WriteTo(w)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//Unpack decodes the details of a ControlPacket after the fixed
|
||||||
|
//header has been read
|
||||||
|
func (ca *ConnackPacket) Unpack(b io.Reader) error {
|
||||||
|
flags, err := decodeByte(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
ca.SessionPresent = 1&flags > 0
|
||||||
|
ca.ReturnCode, err = decodeByte(b)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//Details returns a Details struct containing the Qos and
|
||||||
|
//MessageID of this ControlPacket
|
||||||
|
func (ca *ConnackPacket) Details() Details {
|
||||||
|
return Details{Qos: 0, MessageID: 0}
|
||||||
|
}
|
||||||
+154
@@ -0,0 +1,154 @@
|
|||||||
|
package packets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
//ConnectPacket is an internal representation of the fields of the
|
||||||
|
//Connect MQTT packet
|
||||||
|
type ConnectPacket struct {
|
||||||
|
FixedHeader
|
||||||
|
ProtocolName string
|
||||||
|
ProtocolVersion byte
|
||||||
|
CleanSession bool
|
||||||
|
WillFlag bool
|
||||||
|
WillQos byte
|
||||||
|
WillRetain bool
|
||||||
|
UsernameFlag bool
|
||||||
|
PasswordFlag bool
|
||||||
|
ReservedBit byte
|
||||||
|
Keepalive uint16
|
||||||
|
|
||||||
|
ClientIdentifier string
|
||||||
|
WillTopic string
|
||||||
|
WillMessage []byte
|
||||||
|
Username string
|
||||||
|
Password []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnectPacket) String() string {
|
||||||
|
str := fmt.Sprintf("%s", c.FixedHeader)
|
||||||
|
str += " "
|
||||||
|
str += fmt.Sprintf("protocolversion: %d protocolname: %s cleansession: %t willflag: %t WillQos: %d WillRetain: %t Usernameflag: %t Passwordflag: %t keepalive: %d clientId: %s willtopic: %s willmessage: %s Username: %s Password: %s", c.ProtocolVersion, c.ProtocolName, c.CleanSession, c.WillFlag, c.WillQos, c.WillRetain, c.UsernameFlag, c.PasswordFlag, c.Keepalive, c.ClientIdentifier, c.WillTopic, c.WillMessage, c.Username, c.Password)
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ConnectPacket) Write(w io.Writer) error {
|
||||||
|
var body bytes.Buffer
|
||||||
|
var err error
|
||||||
|
|
||||||
|
body.Write(encodeString(c.ProtocolName))
|
||||||
|
body.WriteByte(c.ProtocolVersion)
|
||||||
|
body.WriteByte(boolToByte(c.CleanSession)<<1 | boolToByte(c.WillFlag)<<2 | c.WillQos<<3 | boolToByte(c.WillRetain)<<5 | boolToByte(c.PasswordFlag)<<6 | boolToByte(c.UsernameFlag)<<7)
|
||||||
|
body.Write(encodeUint16(c.Keepalive))
|
||||||
|
body.Write(encodeString(c.ClientIdentifier))
|
||||||
|
if c.WillFlag {
|
||||||
|
body.Write(encodeString(c.WillTopic))
|
||||||
|
body.Write(encodeBytes(c.WillMessage))
|
||||||
|
}
|
||||||
|
if c.UsernameFlag {
|
||||||
|
body.Write(encodeString(c.Username))
|
||||||
|
}
|
||||||
|
if c.PasswordFlag {
|
||||||
|
body.Write(encodeBytes(c.Password))
|
||||||
|
}
|
||||||
|
c.FixedHeader.RemainingLength = body.Len()
|
||||||
|
packet := c.FixedHeader.pack()
|
||||||
|
packet.Write(body.Bytes())
|
||||||
|
_, err = packet.WriteTo(w)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//Unpack decodes the details of a ControlPacket after the fixed
|
||||||
|
//header has been read
|
||||||
|
func (c *ConnectPacket) Unpack(b io.Reader) error {
|
||||||
|
var err error
|
||||||
|
c.ProtocolName, err = decodeString(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.ProtocolVersion, err = decodeByte(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
options, err := decodeByte(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.ReservedBit = 1 & options
|
||||||
|
c.CleanSession = 1&(options>>1) > 0
|
||||||
|
c.WillFlag = 1&(options>>2) > 0
|
||||||
|
c.WillQos = 3 & (options >> 3)
|
||||||
|
c.WillRetain = 1&(options>>5) > 0
|
||||||
|
c.PasswordFlag = 1&(options>>6) > 0
|
||||||
|
c.UsernameFlag = 1&(options>>7) > 0
|
||||||
|
c.Keepalive, err = decodeUint16(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.ClientIdentifier, err = decodeString(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if c.WillFlag {
|
||||||
|
c.WillTopic, err = decodeString(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.WillMessage, err = decodeBytes(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.UsernameFlag {
|
||||||
|
c.Username, err = decodeString(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if c.PasswordFlag {
|
||||||
|
c.Password, err = decodeBytes(b)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//Validate performs validation of the fields of a Connect packet
|
||||||
|
func (c *ConnectPacket) Validate() byte {
|
||||||
|
if c.PasswordFlag && !c.UsernameFlag {
|
||||||
|
return ErrRefusedBadUsernameOrPassword
|
||||||
|
}
|
||||||
|
if c.ReservedBit != 0 {
|
||||||
|
//Bad reserved bit
|
||||||
|
return ErrProtocolViolation
|
||||||
|
}
|
||||||
|
if (c.ProtocolName == "MQIsdp" && c.ProtocolVersion != 3) || (c.ProtocolName == "MQTT" && c.ProtocolVersion != 4) {
|
||||||
|
//Mismatched or unsupported protocol version
|
||||||
|
return ErrRefusedBadProtocolVersion
|
||||||
|
}
|
||||||
|
if c.ProtocolName != "MQIsdp" && c.ProtocolName != "MQTT" {
|
||||||
|
//Bad protocol name
|
||||||
|
return ErrProtocolViolation
|
||||||
|
}
|
||||||
|
if len(c.ClientIdentifier) > 65535 || len(c.Username) > 65535 || len(c.Password) > 65535 {
|
||||||
|
//Bad size field
|
||||||
|
return ErrProtocolViolation
|
||||||
|
}
|
||||||
|
if len(c.ClientIdentifier) == 0 && !c.CleanSession {
|
||||||
|
//Bad client identifier
|
||||||
|
return ErrRefusedIDRejected
|
||||||
|
}
|
||||||
|
return Accepted
|
||||||
|
}
|
||||||
|
|
||||||
|
//Details returns a Details struct containing the Qos and
|
||||||
|
//MessageID of this ControlPacket
|
||||||
|
func (c *ConnectPacket) Details() Details {
|
||||||
|
return Details{Qos: 0, MessageID: 0}
|
||||||
|
}
|
||||||
+36
@@ -0,0 +1,36 @@
|
|||||||
|
package packets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
//DisconnectPacket is an internal representation of the fields of the
|
||||||
|
//Disconnect MQTT packet
|
||||||
|
type DisconnectPacket struct {
|
||||||
|
FixedHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DisconnectPacket) String() string {
|
||||||
|
str := fmt.Sprintf("%s", d.FixedHeader)
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *DisconnectPacket) Write(w io.Writer) error {
|
||||||
|
packet := d.FixedHeader.pack()
|
||||||
|
_, err := packet.WriteTo(w)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//Unpack decodes the details of a ControlPacket after the fixed
|
||||||
|
//header has been read
|
||||||
|
func (d *DisconnectPacket) Unpack(b io.Reader) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//Details returns a Details struct containing the Qos and
|
||||||
|
//MessageID of this ControlPacket
|
||||||
|
func (d *DisconnectPacket) Details() Details {
|
||||||
|
return Details{Qos: 0, MessageID: 0}
|
||||||
|
}
|
||||||
+346
@@ -0,0 +1,346 @@
|
|||||||
|
package packets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
//ControlPacket defines the interface for structs intended to hold
|
||||||
|
//decoded MQTT packets, either from being read or before being
|
||||||
|
//written
|
||||||
|
type ControlPacket interface {
|
||||||
|
Write(io.Writer) error
|
||||||
|
Unpack(io.Reader) error
|
||||||
|
String() string
|
||||||
|
Details() Details
|
||||||
|
}
|
||||||
|
|
||||||
|
//PacketNames maps the constants for each of the MQTT packet types
|
||||||
|
//to a string representation of their name.
|
||||||
|
var PacketNames = map[uint8]string{
|
||||||
|
1: "CONNECT",
|
||||||
|
2: "CONNACK",
|
||||||
|
3: "PUBLISH",
|
||||||
|
4: "PUBACK",
|
||||||
|
5: "PUBREC",
|
||||||
|
6: "PUBREL",
|
||||||
|
7: "PUBCOMP",
|
||||||
|
8: "SUBSCRIBE",
|
||||||
|
9: "SUBACK",
|
||||||
|
10: "UNSUBSCRIBE",
|
||||||
|
11: "UNSUBACK",
|
||||||
|
12: "PINGREQ",
|
||||||
|
13: "PINGRESP",
|
||||||
|
14: "DISCONNECT",
|
||||||
|
}
|
||||||
|
|
||||||
|
//Below are the constants assigned to each of the MQTT packet types
|
||||||
|
const (
|
||||||
|
Connect = 1
|
||||||
|
Connack = 2
|
||||||
|
Publish = 3
|
||||||
|
Puback = 4
|
||||||
|
Pubrec = 5
|
||||||
|
Pubrel = 6
|
||||||
|
Pubcomp = 7
|
||||||
|
Subscribe = 8
|
||||||
|
Suback = 9
|
||||||
|
Unsubscribe = 10
|
||||||
|
Unsuback = 11
|
||||||
|
Pingreq = 12
|
||||||
|
Pingresp = 13
|
||||||
|
Disconnect = 14
|
||||||
|
)
|
||||||
|
|
||||||
|
//Below are the const definitions for error codes returned by
|
||||||
|
//Connect()
|
||||||
|
const (
|
||||||
|
Accepted = 0x00
|
||||||
|
ErrRefusedBadProtocolVersion = 0x01
|
||||||
|
ErrRefusedIDRejected = 0x02
|
||||||
|
ErrRefusedServerUnavailable = 0x03
|
||||||
|
ErrRefusedBadUsernameOrPassword = 0x04
|
||||||
|
ErrRefusedNotAuthorised = 0x05
|
||||||
|
ErrNetworkError = 0xFE
|
||||||
|
ErrProtocolViolation = 0xFF
|
||||||
|
)
|
||||||
|
|
||||||
|
//ConnackReturnCodes is a map of the error codes constants for Connect()
|
||||||
|
//to a string representation of the error
|
||||||
|
var ConnackReturnCodes = map[uint8]string{
|
||||||
|
0: "Connection Accepted",
|
||||||
|
1: "Connection Refused: Bad Protocol Version",
|
||||||
|
2: "Connection Refused: Client Identifier Rejected",
|
||||||
|
3: "Connection Refused: Server Unavailable",
|
||||||
|
4: "Connection Refused: Username or Password in unknown format",
|
||||||
|
5: "Connection Refused: Not Authorised",
|
||||||
|
254: "Connection Error",
|
||||||
|
255: "Connection Refused: Protocol Violation",
|
||||||
|
}
|
||||||
|
|
||||||
|
//ConnErrors is a map of the errors codes constants for Connect()
|
||||||
|
//to a Go error
|
||||||
|
var ConnErrors = map[byte]error{
|
||||||
|
Accepted: nil,
|
||||||
|
ErrRefusedBadProtocolVersion: errors.New("Unnacceptable protocol version"),
|
||||||
|
ErrRefusedIDRejected: errors.New("Identifier rejected"),
|
||||||
|
ErrRefusedServerUnavailable: errors.New("Server Unavailable"),
|
||||||
|
ErrRefusedBadUsernameOrPassword: errors.New("Bad user name or password"),
|
||||||
|
ErrRefusedNotAuthorised: errors.New("Not Authorized"),
|
||||||
|
ErrNetworkError: errors.New("Network Error"),
|
||||||
|
ErrProtocolViolation: errors.New("Protocol Violation"),
|
||||||
|
}
|
||||||
|
|
||||||
|
//ReadPacket takes an instance of an io.Reader (such as net.Conn) and attempts
|
||||||
|
//to read an MQTT packet from the stream. It returns a ControlPacket
|
||||||
|
//representing the decoded MQTT packet and an error. One of these returns will
|
||||||
|
//always be nil, a nil ControlPacket indicating an error occurred.
|
||||||
|
func ReadPacket(r io.Reader) (ControlPacket, error) {
|
||||||
|
var fh FixedHeader
|
||||||
|
b := make([]byte, 1)
|
||||||
|
|
||||||
|
_, err := io.ReadFull(r, b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
err = fh.unpack(b[0], r)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
cp, err := NewControlPacketWithHeader(fh)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
packetBytes := make([]byte, fh.RemainingLength)
|
||||||
|
n, err := io.ReadFull(r, packetBytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if n != fh.RemainingLength {
|
||||||
|
return nil, errors.New("Failed to read expected data")
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cp.Unpack(bytes.NewBuffer(packetBytes))
|
||||||
|
return cp, err
|
||||||
|
}
|
||||||
|
|
||||||
|
//NewControlPacket is used to create a new ControlPacket of the type specified
|
||||||
|
//by packetType, this is usually done by reference to the packet type constants
|
||||||
|
//defined in packets.go. The newly created ControlPacket is empty and a pointer
|
||||||
|
//is returned.
|
||||||
|
func NewControlPacket(packetType byte) ControlPacket {
|
||||||
|
switch packetType {
|
||||||
|
case Connect:
|
||||||
|
return &ConnectPacket{FixedHeader: FixedHeader{MessageType: Connect}}
|
||||||
|
case Connack:
|
||||||
|
return &ConnackPacket{FixedHeader: FixedHeader{MessageType: Connack}}
|
||||||
|
case Disconnect:
|
||||||
|
return &DisconnectPacket{FixedHeader: FixedHeader{MessageType: Disconnect}}
|
||||||
|
case Publish:
|
||||||
|
return &PublishPacket{FixedHeader: FixedHeader{MessageType: Publish}}
|
||||||
|
case Puback:
|
||||||
|
return &PubackPacket{FixedHeader: FixedHeader{MessageType: Puback}}
|
||||||
|
case Pubrec:
|
||||||
|
return &PubrecPacket{FixedHeader: FixedHeader{MessageType: Pubrec}}
|
||||||
|
case Pubrel:
|
||||||
|
return &PubrelPacket{FixedHeader: FixedHeader{MessageType: Pubrel, Qos: 1}}
|
||||||
|
case Pubcomp:
|
||||||
|
return &PubcompPacket{FixedHeader: FixedHeader{MessageType: Pubcomp}}
|
||||||
|
case Subscribe:
|
||||||
|
return &SubscribePacket{FixedHeader: FixedHeader{MessageType: Subscribe, Qos: 1}}
|
||||||
|
case Suback:
|
||||||
|
return &SubackPacket{FixedHeader: FixedHeader{MessageType: Suback}}
|
||||||
|
case Unsubscribe:
|
||||||
|
return &UnsubscribePacket{FixedHeader: FixedHeader{MessageType: Unsubscribe, Qos: 1}}
|
||||||
|
case Unsuback:
|
||||||
|
return &UnsubackPacket{FixedHeader: FixedHeader{MessageType: Unsuback}}
|
||||||
|
case Pingreq:
|
||||||
|
return &PingreqPacket{FixedHeader: FixedHeader{MessageType: Pingreq}}
|
||||||
|
case Pingresp:
|
||||||
|
return &PingrespPacket{FixedHeader: FixedHeader{MessageType: Pingresp}}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//NewControlPacketWithHeader is used to create a new ControlPacket of the type
|
||||||
|
//specified within the FixedHeader that is passed to the function.
|
||||||
|
//The newly created ControlPacket is empty and a pointer is returned.
|
||||||
|
func NewControlPacketWithHeader(fh FixedHeader) (ControlPacket, error) {
|
||||||
|
switch fh.MessageType {
|
||||||
|
case Connect:
|
||||||
|
return &ConnectPacket{FixedHeader: fh}, nil
|
||||||
|
case Connack:
|
||||||
|
return &ConnackPacket{FixedHeader: fh}, nil
|
||||||
|
case Disconnect:
|
||||||
|
return &DisconnectPacket{FixedHeader: fh}, nil
|
||||||
|
case Publish:
|
||||||
|
return &PublishPacket{FixedHeader: fh}, nil
|
||||||
|
case Puback:
|
||||||
|
return &PubackPacket{FixedHeader: fh}, nil
|
||||||
|
case Pubrec:
|
||||||
|
return &PubrecPacket{FixedHeader: fh}, nil
|
||||||
|
case Pubrel:
|
||||||
|
return &PubrelPacket{FixedHeader: fh}, nil
|
||||||
|
case Pubcomp:
|
||||||
|
return &PubcompPacket{FixedHeader: fh}, nil
|
||||||
|
case Subscribe:
|
||||||
|
return &SubscribePacket{FixedHeader: fh}, nil
|
||||||
|
case Suback:
|
||||||
|
return &SubackPacket{FixedHeader: fh}, nil
|
||||||
|
case Unsubscribe:
|
||||||
|
return &UnsubscribePacket{FixedHeader: fh}, nil
|
||||||
|
case Unsuback:
|
||||||
|
return &UnsubackPacket{FixedHeader: fh}, nil
|
||||||
|
case Pingreq:
|
||||||
|
return &PingreqPacket{FixedHeader: fh}, nil
|
||||||
|
case Pingresp:
|
||||||
|
return &PingrespPacket{FixedHeader: fh}, nil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("unsupported packet type 0x%x", fh.MessageType)
|
||||||
|
}
|
||||||
|
|
||||||
|
//Details struct returned by the Details() function called on
|
||||||
|
//ControlPackets to present details of the Qos and MessageID
|
||||||
|
//of the ControlPacket
|
||||||
|
type Details struct {
|
||||||
|
Qos byte
|
||||||
|
MessageID uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
//FixedHeader is a struct to hold the decoded information from
|
||||||
|
//the fixed header of an MQTT ControlPacket
|
||||||
|
type FixedHeader struct {
|
||||||
|
MessageType byte
|
||||||
|
Dup bool
|
||||||
|
Qos byte
|
||||||
|
Retain bool
|
||||||
|
RemainingLength int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fh FixedHeader) String() string {
|
||||||
|
return fmt.Sprintf("%s: dup: %t qos: %d retain: %t rLength: %d", PacketNames[fh.MessageType], fh.Dup, fh.Qos, fh.Retain, fh.RemainingLength)
|
||||||
|
}
|
||||||
|
|
||||||
|
func boolToByte(b bool) byte {
|
||||||
|
switch b {
|
||||||
|
case true:
|
||||||
|
return 1
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fh *FixedHeader) pack() bytes.Buffer {
|
||||||
|
var header bytes.Buffer
|
||||||
|
header.WriteByte(fh.MessageType<<4 | boolToByte(fh.Dup)<<3 | fh.Qos<<1 | boolToByte(fh.Retain))
|
||||||
|
header.Write(encodeLength(fh.RemainingLength))
|
||||||
|
return header
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fh *FixedHeader) unpack(typeAndFlags byte, r io.Reader) error {
|
||||||
|
fh.MessageType = typeAndFlags >> 4
|
||||||
|
fh.Dup = (typeAndFlags>>3)&0x01 > 0
|
||||||
|
fh.Qos = (typeAndFlags >> 1) & 0x03
|
||||||
|
fh.Retain = typeAndFlags&0x01 > 0
|
||||||
|
|
||||||
|
var err error
|
||||||
|
fh.RemainingLength, err = decodeLength(r)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeByte(b io.Reader) (byte, error) {
|
||||||
|
num := make([]byte, 1)
|
||||||
|
_, err := b.Read(num)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return num[0], nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeUint16(b io.Reader) (uint16, error) {
|
||||||
|
num := make([]byte, 2)
|
||||||
|
_, err := b.Read(num)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return binary.BigEndian.Uint16(num), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeUint16(num uint16) []byte {
|
||||||
|
bytes := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(bytes, num)
|
||||||
|
return bytes
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeString(field string) []byte {
|
||||||
|
return encodeBytes([]byte(field))
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeString(b io.Reader) (string, error) {
|
||||||
|
buf, err := decodeBytes(b)
|
||||||
|
return string(buf), err
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeBytes(b io.Reader) ([]byte, error) {
|
||||||
|
fieldLength, err := decodeUint16(b)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
field := make([]byte, fieldLength)
|
||||||
|
_, err = b.Read(field)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return field, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeBytes(field []byte) []byte {
|
||||||
|
fieldLength := make([]byte, 2)
|
||||||
|
binary.BigEndian.PutUint16(fieldLength, uint16(len(field)))
|
||||||
|
return append(fieldLength, field...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeLength(length int) []byte {
|
||||||
|
var encLength []byte
|
||||||
|
for {
|
||||||
|
digit := byte(length % 128)
|
||||||
|
length /= 128
|
||||||
|
if length > 0 {
|
||||||
|
digit |= 0x80
|
||||||
|
}
|
||||||
|
encLength = append(encLength, digit)
|
||||||
|
if length == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return encLength
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeLength(r io.Reader) (int, error) {
|
||||||
|
var rLength uint32
|
||||||
|
var multiplier uint32
|
||||||
|
b := make([]byte, 1)
|
||||||
|
for multiplier < 27 { //fix: Infinite '(digit & 128) == 1' will cause the dead loop
|
||||||
|
_, err := io.ReadFull(r, b)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
digit := b[0]
|
||||||
|
rLength |= uint32(digit&127) << multiplier
|
||||||
|
if (digit & 128) == 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
multiplier += 7
|
||||||
|
}
|
||||||
|
return int(rLength), nil
|
||||||
|
}
|
||||||
+36
@@ -0,0 +1,36 @@
|
|||||||
|
package packets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
//PingreqPacket is an internal representation of the fields of the
|
||||||
|
//Pingreq MQTT packet
|
||||||
|
type PingreqPacket struct {
|
||||||
|
FixedHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pr *PingreqPacket) String() string {
|
||||||
|
str := fmt.Sprintf("%s", pr.FixedHeader)
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pr *PingreqPacket) Write(w io.Writer) error {
|
||||||
|
packet := pr.FixedHeader.pack()
|
||||||
|
_, err := packet.WriteTo(w)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//Unpack decodes the details of a ControlPacket after the fixed
|
||||||
|
//header has been read
|
||||||
|
func (pr *PingreqPacket) Unpack(b io.Reader) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//Details returns a Details struct containing the Qos and
|
||||||
|
//MessageID of this ControlPacket
|
||||||
|
func (pr *PingreqPacket) Details() Details {
|
||||||
|
return Details{Qos: 0, MessageID: 0}
|
||||||
|
}
|
||||||
+36
@@ -0,0 +1,36 @@
|
|||||||
|
package packets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
//PingrespPacket is an internal representation of the fields of the
|
||||||
|
//Pingresp MQTT packet
|
||||||
|
type PingrespPacket struct {
|
||||||
|
FixedHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pr *PingrespPacket) String() string {
|
||||||
|
str := fmt.Sprintf("%s", pr.FixedHeader)
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pr *PingrespPacket) Write(w io.Writer) error {
|
||||||
|
packet := pr.FixedHeader.pack()
|
||||||
|
_, err := packet.WriteTo(w)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//Unpack decodes the details of a ControlPacket after the fixed
|
||||||
|
//header has been read
|
||||||
|
func (pr *PingrespPacket) Unpack(b io.Reader) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//Details returns a Details struct containing the Qos and
|
||||||
|
//MessageID of this ControlPacket
|
||||||
|
func (pr *PingrespPacket) Details() Details {
|
||||||
|
return Details{Qos: 0, MessageID: 0}
|
||||||
|
}
|
||||||
+45
@@ -0,0 +1,45 @@
|
|||||||
|
package packets
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
//PubackPacket is an internal representation of the fields of the
|
||||||
|
//Puback MQTT packet
|
||||||
|
type PubackPacket struct {
|
||||||
|
FixedHeader
|
||||||
|
MessageID uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pa *PubackPacket) String() string {
|
||||||
|
str := fmt.Sprintf("%s", pa.FixedHeader)
|
||||||
|
str += " "
|
||||||
|
str += fmt.Sprintf("MessageID: %d", pa.MessageID)
|
||||||
|
return str
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pa *PubackPacket) Write(w io.Writer) error {
|
||||||
|
var err error
|
||||||
|
pa.FixedHeader.RemainingLength = 2
|
||||||
|
packet := pa.FixedHeader.pack()
|
||||||
|
packet.Write(encodeUint16(pa.MessageID))
|
||||||
|
_, err = packet.WriteTo(w)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//Unpack decodes the details of a ControlPacket after the fixed
|
||||||
|
//header has been read
|
||||||
|
func (pa *PubackPacket) Unpack(b io.Reader) error {
|
||||||
|
var err error
|
||||||
|
pa.MessageID, err = decodeUint16(b)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
//Details returns a Details struct containing the Qos and
|
||||||
|
//MessageID of this ControlPacket
|
||||||
|
func (pa *PubackPacket) Details() Details {
|
||||||
|
return Details{Qos: pa.Qos, MessageID: pa.MessageID}
|
||||||
|
}
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user