mirror of
https://github.com/fhmq/hmq.git
synced 2026-05-04 07:08:32 +00:00
Compare commits
90 Commits
packet
...
plugin_upd
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
69a26f8cd9 | ||
|
|
148738800b | ||
|
|
e4e736d1e2 | ||
|
|
4c5a48a44b | ||
|
|
c6b1f1db42 | ||
|
|
daf4a0e0f5 | ||
|
|
c350d16ca1 | ||
|
|
edc46c1ee6 | ||
|
|
6193be74fa | ||
|
|
90beada459 | ||
|
|
6c7fe6a0f7 | ||
|
|
2b56664d85 | ||
|
|
7547ad3bdc | ||
|
|
84e7fe2490 | ||
|
|
684584b208 | ||
|
|
56fb4a2d54 | ||
|
|
5ed4728575 | ||
|
|
c0fea6a5ba | ||
|
|
47500910e1 | ||
|
|
0ff20b6ee2 | ||
|
|
7155667f6c | ||
|
|
83db82cdcc | ||
|
|
b3653bcfb1 | ||
|
|
221d00480e | ||
|
|
91733bf91e | ||
|
|
ef252550dc | ||
|
|
1058256235 | ||
|
|
5a569f14a3 | ||
|
|
93b21777ff | ||
|
|
dcf2934e1b | ||
|
|
d9e6e216b0 | ||
|
|
ca3951769a | ||
|
|
0439e7ce90 | ||
|
|
dc0f2185ab | ||
|
|
7462afcfb5 | ||
|
|
114e6f901e | ||
|
|
0cb51bd37a | ||
|
|
819b4725f2 | ||
|
|
85bdeccbfc | ||
|
|
1339a04b28 | ||
|
|
957329d85c | ||
|
|
7db7edaa17 | ||
|
|
1d6f6a4a71 | ||
|
|
123bb7210f | ||
|
|
9ad6590e83 | ||
|
|
516db49db5 | ||
|
|
a260057bfe | ||
|
|
bdd802ebfb | ||
|
|
5786e69b01 | ||
|
|
6a89b627d4 | ||
|
|
208a7cf0a8 | ||
|
|
a7fb7f1912 | ||
|
|
eeab0c6b7d | ||
|
|
4646042b7f | ||
|
|
49385e52fd | ||
|
|
3ed8625bb9 | ||
|
|
6b50060eae | ||
|
|
96277996f0 | ||
|
|
5601632a33 | ||
|
|
cc1b3239ad | ||
|
|
476d22568b | ||
|
|
c85ba76f8f | ||
|
|
f3b2924b07 | ||
|
|
6144aeb6bf | ||
|
|
83b6934621 | ||
|
|
7d6fcb7d65 | ||
|
|
35390baa92 | ||
|
|
f2efaa9992 | ||
|
|
258912b33a | ||
|
|
7f45bd6bc9 | ||
|
|
7073e9b4ba | ||
|
|
8c98346546 | ||
|
|
4300a32f6b | ||
|
|
ae1af54c6e | ||
|
|
34378164e0 | ||
|
|
47c44570fc | ||
|
|
417a12174c | ||
|
|
43a6bb8c5d | ||
|
|
18d18738be | ||
|
|
8af790ffba | ||
|
|
5e937601ce | ||
|
|
31548b10e5 | ||
|
|
ca7ebfb6e3 | ||
|
|
b98ae9ec6f | ||
|
|
8bf6ccaa25 | ||
|
|
65ac09cf50 | ||
|
|
d37100d059 | ||
|
|
50a9a6841d | ||
|
|
a45cccaa7a | ||
|
|
c732d395e1 |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1,3 +1,4 @@
|
||||
hmq
|
||||
log
|
||||
log/*
|
||||
*.test
|
||||
8
.vscode/settings.json
vendored
Normal file
8
.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,8 @@
|
||||
{
|
||||
"go.lintFlags": [
|
||||
"--disable=all",
|
||||
"--enable=errcheck,varcheck,deadcode",
|
||||
"--enable=varcheck",
|
||||
"--enable=deadcode"
|
||||
]
|
||||
}
|
||||
11
Dcokerfile
11
Dcokerfile
@@ -1,11 +0,0 @@
|
||||
FROM alpine
|
||||
COPY hmq /
|
||||
COPY broker.config /
|
||||
COPY tls /tls
|
||||
COPY conf /conf
|
||||
|
||||
EXPOSE 1883
|
||||
EXPOSE 1888
|
||||
EXPOSE 1993
|
||||
|
||||
CMD ["/hmq"]
|
||||
13
Dockerfile
Normal file
13
Dockerfile
Normal file
@@ -0,0 +1,13 @@
|
||||
FROM golang:1.12 as builder
|
||||
WORKDIR /go/src/github.com/fhmq/hmq
|
||||
COPY . .
|
||||
COPY ./vendor .
|
||||
RUN CGO_ENABLED=0 go build -o hmq -a -ldflags '-extldflags "-static"' .
|
||||
|
||||
|
||||
FROM alpine:3.8
|
||||
WORKDIR /
|
||||
COPY --from=builder /go/src/github.com/fhmq/hmq/hmq .
|
||||
EXPOSE 1883
|
||||
|
||||
CMD ["/hmq"]
|
||||
2
lib/message/LICENSE → LICENSE
Executable file → Normal file
2
lib/message/LICENSE → LICENSE
Executable file → Normal file
@@ -1,4 +1,4 @@
|
||||
Apache License
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
148
README.md
148
README.md
@@ -3,25 +3,52 @@ Free and High Performance MQTT Broker
|
||||
|
||||
## About
|
||||
Golang MQTT Broker, Version 3.1.1, and Compatible
|
||||
for [eclipse paho client](https://github.com/eclipse?utf8=%E2%9C%93&q=mqtt&type=&language=)
|
||||
for [eclipse paho client](https://github.com/eclipse?utf8=%E2%9C%93&q=mqtt&type=&language=) and mosquitto-client
|
||||
|
||||
Download: [click here](https://github.com/fhmq/hmq/releases)
|
||||
|
||||
## RUNNING
|
||||
```bash
|
||||
$ git clone https://github.com/fhmq/hmq.git
|
||||
$ cd hmq
|
||||
$ go get github.com/fhmq/hmq
|
||||
$ cd $GOPATH/github.com/fhmq/hmq
|
||||
$ go run main.go
|
||||
```
|
||||
|
||||
### broker.config
|
||||
## Usage of hmq:
|
||||
~~~
|
||||
Usage: hmq [options]
|
||||
|
||||
Broker Options:
|
||||
-w, --worker <number> Worker num to process message, perfer (client num)/10. (default 1024)
|
||||
-p, --port <port> Use port for clients (default: 1883)
|
||||
--host <host> Network host to listen on. (default "0.0.0.0")
|
||||
-ws, --wsport <port> Use port for websocket monitoring
|
||||
-wsp,--wspath <path> Use path for websocket monitoring
|
||||
-c, --config <file> Configuration file
|
||||
|
||||
Logging Options:
|
||||
-d, --debug <bool> Enable debugging output (default false)
|
||||
-D Debug enabled
|
||||
|
||||
Cluster Options:
|
||||
-r, --router <rurl> Router who maintenance cluster info
|
||||
-cp, --clusterport <cluster-port> Cluster listen port for others
|
||||
|
||||
Common Options:
|
||||
-h, --help Show this message
|
||||
~~~
|
||||
|
||||
### hmq.config
|
||||
~~~
|
||||
{
|
||||
"workerNum": 4096,
|
||||
"port": "1883",
|
||||
"host": "0.0.0.0",
|
||||
"cluster": {
|
||||
"host": "0.0.0.0",
|
||||
"port": "1993",
|
||||
"routers": ["10.10.0.11:1993","10.10.0.12:1993"]
|
||||
"port": "1993"
|
||||
},
|
||||
"router": "127.0.0.1:9888",
|
||||
"wsPort": "1888",
|
||||
"wsPath": "/ws",
|
||||
"wsTLS": true,
|
||||
@@ -33,14 +60,16 @@ $ go run main.go
|
||||
"certFile": "tls/server/cert.pem",
|
||||
"keyFile": "tls/server/key.pem"
|
||||
},
|
||||
"acl":true,
|
||||
"aclConf":"conf/acl.conf"
|
||||
"plugins": {
|
||||
"auth": "authhttp",
|
||||
"bridge": "kafka"
|
||||
}
|
||||
}
|
||||
~~~
|
||||
|
||||
### Features and Future
|
||||
|
||||
* Supports QOS 0
|
||||
* Supports QOS 0 and 1
|
||||
|
||||
* Cluster Support
|
||||
|
||||
@@ -50,72 +79,47 @@ $ go run main.go
|
||||
|
||||
* Supports will messages
|
||||
|
||||
* Queue subscribe
|
||||
|
||||
* Websocket Support
|
||||
|
||||
* TLS/SSL Support
|
||||
|
||||
* Flexible ACL
|
||||
* AuthHTTP Support
|
||||
* Auth Connect
|
||||
* Auth ACL
|
||||
* Cache Support
|
||||
|
||||
### QUEUE SUBSCRIBE
|
||||
* Kafka Bridge Support
|
||||
* Action Deliver
|
||||
* Regexp Deliver
|
||||
|
||||
* HTTP API
|
||||
* Disconnect Connect (future more)
|
||||
|
||||
### Share SUBSCRIBE
|
||||
~~~
|
||||
| Prefix | Examples |
|
||||
| ------------- |---------------------------------|
|
||||
| $queue/ | mosquitto_sub -t ‘$queue/topic’ |
|
||||
| Prefix | Examples | Publish |
|
||||
| ------------------- |-------------------------------------------|--------------------------- --|
|
||||
| $share/<group>/topic | mosquitto_sub -t ‘$share/<group>/topic’ | mosquitto_pub -t ‘topic’ |
|
||||
~~~
|
||||
|
||||
### ACL Configure
|
||||
#### The ACL rules define:
|
||||
~~~
|
||||
Allow | type | value | pubsub | Topics
|
||||
~~~
|
||||
#### ACL Config
|
||||
~~~
|
||||
## type clientid , username, ipaddr
|
||||
##pub 1 , sub 2, pubsub 3
|
||||
## %c is clientid , %u is username
|
||||
allow ip 127.0.0.1 2 $SYS/#
|
||||
allow clientid 0001 3 #
|
||||
allow username admin 3 #
|
||||
allow username joy 3 /test,hello/world
|
||||
allow clientid * 1 toCloud/%c
|
||||
allow username * 1 toCloud/%u
|
||||
deny clientid * 3 #
|
||||
~~~
|
||||
### Cluster
|
||||
```bash
|
||||
1, start router for hmq (https://github.com/fhmq/router.git)
|
||||
$ go get github.com/fhmq/router
|
||||
$ cd $GOPATH/github.com/fhmq/router
|
||||
$ go run main.go
|
||||
2, config router in hmq.config ("router": "127.0.0.1:9888")
|
||||
|
||||
```
|
||||
|
||||
~~~
|
||||
#allow local sub $SYS topic
|
||||
allow ip 127.0.0.1 2 $SYS/#
|
||||
~~~
|
||||
~~~
|
||||
#allow client who's id with 0001 or username with admin pub sub all topic
|
||||
allow clientid 0001 3 #
|
||||
allow username admin 3 #
|
||||
~~~
|
||||
~~~
|
||||
#allow client with the username joy can pub sub topic '/test' and 'hello/world'
|
||||
allow username joy 3 /test,hello/world
|
||||
~~~
|
||||
~~~
|
||||
#allow all client pub the topic toCloud/{clientid/username}
|
||||
allow clientid * 1 toCloud/%c
|
||||
allow username * 1 toCloud/%u
|
||||
~~~
|
||||
~~~
|
||||
#deny all client pub sub all topic
|
||||
deny clientid * 3 #
|
||||
~~~
|
||||
Client match acl rule one by one
|
||||
~~~
|
||||
--------- --------- ---------
|
||||
Client -> | Rule1 | --nomatch--> | Rule2 | --nomatch--> | Rule3 | -->
|
||||
--------- --------- ---------
|
||||
| | |
|
||||
match match match
|
||||
\|/ \|/ \|/
|
||||
allow | deny allow | deny allow | deny
|
||||
~~~
|
||||
|
||||
### Online/Offline Notification
|
||||
```bash
|
||||
topic:
|
||||
$SYS/broker/connection/clients/<clientID>
|
||||
payload:
|
||||
{"clientID":"client001","online":true/false,"timestamp":"2018-10-25T09:32:32Z"}
|
||||
```
|
||||
|
||||
## Performance
|
||||
|
||||
@@ -128,4 +132,14 @@ Client -> | Rule1 | --nomatch--> | Rule2 | --nomatch--> | Rule3 | -->
|
||||
|
||||
## License
|
||||
|
||||
* Apache License Version 2.0
|
||||
* Apache License Version 2.0
|
||||
|
||||
|
||||
## Reference
|
||||
|
||||
* Surgermq.(https://github.com/surgemq/surgemq)
|
||||
|
||||
## Benchmark Tool
|
||||
|
||||
* https://github.com/inovex/mqtt-stresser
|
||||
* https://github.com/krylovsk/mqtt-benchmark
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
theme: jekyll-theme-slate
|
||||
@@ -1,80 +1,45 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package broker
|
||||
|
||||
import (
|
||||
"hmq/lib/acl"
|
||||
"strings"
|
||||
|
||||
log "github.com/cihub/seelog"
|
||||
"github.com/fsnotify/fsnotify"
|
||||
)
|
||||
|
||||
const (
|
||||
PUB = 1
|
||||
SUB = 2
|
||||
SUB = "1"
|
||||
PUB = "2"
|
||||
)
|
||||
|
||||
func (c *client) CheckTopicAuth(typ int, topic string) bool {
|
||||
if !c.broker.config.Acl {
|
||||
return true
|
||||
func (b *Broker) CheckTopicAuth(action, username, topic string) bool {
|
||||
if b.auth != nil {
|
||||
if strings.HasPrefix(topic, "$SYS/broker/connection/clients/") {
|
||||
return true
|
||||
}
|
||||
|
||||
if strings.HasPrefix(topic, "$share/") && action == SUB {
|
||||
substr := groupCompile.FindStringSubmatch(topic)
|
||||
if len(substr) != 3 {
|
||||
return false
|
||||
}
|
||||
topic = substr[2]
|
||||
}
|
||||
|
||||
return b.auth.CheckACL(action, username, topic)
|
||||
}
|
||||
if strings.HasPrefix(topic, "$queue/") {
|
||||
topic = string([]byte(topic)[7:])
|
||||
if topic == "" {
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
func (b *Broker) CheckConnectAuth(clientID, username, password string) bool {
|
||||
if b.auth != nil {
|
||||
if clientID == "" || username == "" {
|
||||
return false
|
||||
}
|
||||
return b.auth.CheckConnect(clientID, username, password)
|
||||
}
|
||||
ip := c.info.remoteIP
|
||||
username := string(c.info.username)
|
||||
clientid := string(c.info.clientID)
|
||||
aclInfo := c.broker.AclConfig
|
||||
return acl.CheckTopicAuth(aclInfo, typ, ip, username, clientid, topic)
|
||||
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
var (
|
||||
watchList = []string{"./conf"}
|
||||
)
|
||||
|
||||
func (b *Broker) handleFsEvent(event fsnotify.Event) error {
|
||||
switch event.Name {
|
||||
case b.config.AclConf:
|
||||
if event.Op&fsnotify.Write == fsnotify.Write ||
|
||||
event.Op&fsnotify.Create == fsnotify.Create {
|
||||
log.Info("text:handling acl config change event:", event)
|
||||
aclconfig, err := acl.AclConfigLoad(event.Name)
|
||||
if err != nil {
|
||||
log.Error("aclconfig change failed, load acl conf error: ", err)
|
||||
return err
|
||||
}
|
||||
b.AclConfig = aclconfig
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Broker) StartAclWatcher() {
|
||||
go func() {
|
||||
wch, e := fsnotify.NewWatcher()
|
||||
if e != nil {
|
||||
log.Error("start monitor acl config file error,", e)
|
||||
return
|
||||
}
|
||||
defer wch.Close()
|
||||
|
||||
for _, i := range watchList {
|
||||
if err := wch.Add(i); err != nil {
|
||||
log.Error("start monitor acl config file error,", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
log.Info("watching acl config file change...")
|
||||
for {
|
||||
select {
|
||||
case evt := <-wch.Events:
|
||||
b.handleFsEvent(evt)
|
||||
case err := <-wch.Errors:
|
||||
log.Error("error:", err.Error())
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
15
broker/bridge.go
Normal file
15
broker/bridge.go
Normal file
@@ -0,0 +1,15 @@
|
||||
package broker
|
||||
|
||||
import (
|
||||
"github.com/fhmq/hmq/plugins/bridge"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (b *Broker) Publish(e *bridge.Elements) {
|
||||
if b.bridgeMQ != nil {
|
||||
err := b.bridgeMQ.Publish(e)
|
||||
if err != nil {
|
||||
log.Error("send message to mq error.", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
719
broker/broker.go
719
broker/broker.go
@@ -1,107 +1,212 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package broker
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"hmq/lib/acl"
|
||||
"hmq/lib/message"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync/atomic"
|
||||
"runtime/debug"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/websocket"
|
||||
"github.com/fhmq/hmq/plugins/bridge"
|
||||
|
||||
log "github.com/cihub/seelog"
|
||||
"github.com/fhmq/hmq/plugins/auth"
|
||||
|
||||
"github.com/fhmq/hmq/broker/lib/sessions"
|
||||
"github.com/fhmq/hmq/broker/lib/topics"
|
||||
|
||||
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||
"github.com/fhmq/hmq/pool"
|
||||
"github.com/shirou/gopsutil/mem"
|
||||
"go.uber.org/zap"
|
||||
"golang.org/x/net/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
MessagePoolNum = 1024
|
||||
MessagePoolMessageNum = 1024
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
client *client
|
||||
packet packets.ControlPacket
|
||||
}
|
||||
|
||||
type Broker struct {
|
||||
id string
|
||||
cid uint64
|
||||
config *Config
|
||||
tlsConfig *tls.Config
|
||||
AclConfig *acl.ACLConfig
|
||||
clients cMap
|
||||
routes cMap
|
||||
remotes cMap
|
||||
sl *Sublist
|
||||
rl *RetainList
|
||||
queues map[string]int
|
||||
id string
|
||||
mu sync.Mutex
|
||||
config *Config
|
||||
tlsConfig *tls.Config
|
||||
wpool *pool.WorkerPool
|
||||
clients sync.Map
|
||||
routes sync.Map
|
||||
remotes sync.Map
|
||||
nodes map[string]interface{}
|
||||
clusterPool chan *Message
|
||||
topicsMgr *topics.Manager
|
||||
sessionMgr *sessions.Manager
|
||||
auth auth.Auth
|
||||
bridgeMQ bridge.BridgeMQ
|
||||
}
|
||||
|
||||
func newMessagePool() []chan *Message {
|
||||
pool := make([]chan *Message, 0)
|
||||
for i := 0; i < MessagePoolNum; i++ {
|
||||
ch := make(chan *Message, MessagePoolMessageNum)
|
||||
pool = append(pool, ch)
|
||||
}
|
||||
return pool
|
||||
}
|
||||
|
||||
func NewBroker(config *Config) (*Broker, error) {
|
||||
b := &Broker{
|
||||
id: GenUniqueId(),
|
||||
config: config,
|
||||
sl: NewSublist(),
|
||||
rl: NewRetainList(),
|
||||
queues: make(map[string]int),
|
||||
clients: NewClientMap(),
|
||||
routes: NewClientMap(),
|
||||
remotes: NewClientMap(),
|
||||
if config == nil {
|
||||
config = DefaultConfig
|
||||
}
|
||||
|
||||
b := &Broker{
|
||||
id: GenUniqueId(),
|
||||
config: config,
|
||||
wpool: pool.New(config.Worker),
|
||||
nodes: make(map[string]interface{}),
|
||||
clusterPool: make(chan *Message),
|
||||
}
|
||||
|
||||
var err error
|
||||
b.topicsMgr, err = topics.NewManager("mem")
|
||||
if err != nil {
|
||||
log.Error("new topic manager error", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
b.sessionMgr, err = sessions.NewManager("mem")
|
||||
if err != nil {
|
||||
log.Error("new session manager error", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if b.config.TlsPort != "" {
|
||||
tlsconfig, err := NewTLSConfig(b.config.TlsInfo)
|
||||
if err != nil {
|
||||
log.Error("new tlsConfig error: ", err)
|
||||
log.Error("new tlsConfig error", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
b.tlsConfig = tlsconfig
|
||||
}
|
||||
if b.config.Acl {
|
||||
aclconfig, err := acl.AclConfigLoad(b.config.AclConf)
|
||||
if err != nil {
|
||||
log.Error("Load acl conf error: ", err)
|
||||
return nil, err
|
||||
}
|
||||
b.AclConfig = aclconfig
|
||||
b.StartAclWatcher()
|
||||
}
|
||||
|
||||
b.auth = auth.NewAuth(b.config.Plugin.Auth)
|
||||
b.bridgeMQ = bridge.NewBridgeMQ(b.config.Plugin.Bridge)
|
||||
|
||||
return b, nil
|
||||
}
|
||||
|
||||
func (b *Broker) SubmitWork(clientId string, msg *Message) {
|
||||
if b.wpool == nil {
|
||||
b.wpool = pool.New(b.config.Worker)
|
||||
}
|
||||
|
||||
if msg.client.typ == CLUSTER {
|
||||
b.clusterPool <- msg
|
||||
} else {
|
||||
b.wpool.Submit(clientId, func() {
|
||||
ProcessMessage(msg)
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (b *Broker) Start() {
|
||||
if b == nil {
|
||||
log.Error("broker is null")
|
||||
return
|
||||
}
|
||||
|
||||
go InitHTTPMoniter(b)
|
||||
|
||||
//listen clinet over tcp
|
||||
if b.config.Port != "" {
|
||||
go b.StartListening(CLIENT)
|
||||
go b.StartClientListening(false)
|
||||
}
|
||||
|
||||
//listen for cluster
|
||||
if b.config.Cluster.Port != "" {
|
||||
go b.StartListening(ROUTER)
|
||||
go b.StartClusterListening()
|
||||
}
|
||||
|
||||
//listen for websocket
|
||||
if b.config.WsPort != "" {
|
||||
go b.StartWebsocketListening()
|
||||
}
|
||||
|
||||
//listen client over tls
|
||||
if b.config.TlsPort != "" {
|
||||
go b.StartTLSListening()
|
||||
go b.StartClientListening(true)
|
||||
}
|
||||
|
||||
//connect on other node in cluster
|
||||
if b.config.Router != "" {
|
||||
go b.processClusterInfo()
|
||||
b.ConnectToDiscovery()
|
||||
}
|
||||
|
||||
//system monitor
|
||||
go StateMonitor()
|
||||
|
||||
}
|
||||
|
||||
func StateMonitor() {
|
||||
v, _ := mem.VirtualMemory()
|
||||
timeSticker := time.NewTicker(time.Second * 30)
|
||||
for {
|
||||
select {
|
||||
case <-timeSticker.C:
|
||||
if v.UsedPercent > 75 {
|
||||
debug.FreeOSMemory()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) StartWebsocketListening() {
|
||||
path := b.config.WsPath
|
||||
hp := ":" + b.config.WsPort
|
||||
log.Info("Start Webscoker Listening on ", hp, path)
|
||||
log.Info("Start Websocket Listener on:", zap.String("hp", hp), zap.String("path", path))
|
||||
http.Handle(path, websocket.Handler(b.wsHandler))
|
||||
err := http.ListenAndServe(hp, nil)
|
||||
var err error
|
||||
if b.config.WsTLS {
|
||||
err = http.ListenAndServeTLS(hp, b.config.TlsInfo.CertFile, b.config.TlsInfo.KeyFile, nil)
|
||||
} else {
|
||||
err = http.ListenAndServe(hp, nil)
|
||||
}
|
||||
if err != nil {
|
||||
log.Error("ListenAndServe: " + err.Error())
|
||||
log.Error("ListenAndServe:" + err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) wsHandler(ws *websocket.Conn) {
|
||||
atomic.AddUint64(&b.cid, 1)
|
||||
go b.handleConnection(CLIENT, ws, b.cid)
|
||||
// io.Copy(ws, ws)
|
||||
ws.PayloadType = websocket.BinaryFrame
|
||||
b.handleConnection(CLIENT, ws)
|
||||
}
|
||||
|
||||
func (b *Broker) StartTLSListening() {
|
||||
hp := b.config.TlsHost + ":" + b.config.TlsPort
|
||||
log.Info("Start TLS Listening client on ", hp)
|
||||
|
||||
l, e := tls.Listen("tcp", hp, b.tlsConfig)
|
||||
if e != nil {
|
||||
log.Error("Error listening on ", e)
|
||||
func (b *Broker) StartClientListening(Tls bool) {
|
||||
var hp string
|
||||
var err error
|
||||
var l net.Listener
|
||||
if Tls {
|
||||
hp = b.config.TlsHost + ":" + b.config.TlsPort
|
||||
l, err = tls.Listen("tcp", hp, b.tlsConfig)
|
||||
log.Info("Start TLS Listening client on ", zap.String("hp", hp))
|
||||
} else {
|
||||
hp := b.config.Host + ":" + b.config.Port
|
||||
l, err = net.Listen("tcp", hp)
|
||||
log.Info("Start Listening client on ", zap.String("hp", hp))
|
||||
}
|
||||
if err != nil {
|
||||
log.Error("Error listening on ", zap.Error(err))
|
||||
return
|
||||
}
|
||||
tmpDelay := 10 * ACCEPT_MIN_SLEEP
|
||||
@@ -110,36 +215,59 @@ func (b *Broker) StartTLSListening() {
|
||||
if err != nil {
|
||||
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||
log.Error("Temporary Client Accept Error(%v), sleeping %dms",
|
||||
ne, tmpDelay/time.Millisecond)
|
||||
zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond))
|
||||
time.Sleep(tmpDelay)
|
||||
tmpDelay *= 2
|
||||
if tmpDelay > ACCEPT_MAX_SLEEP {
|
||||
tmpDelay = ACCEPT_MAX_SLEEP
|
||||
}
|
||||
} else {
|
||||
log.Error("Accept error: %v", err)
|
||||
log.Error("Accept error: %v", zap.Error(err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
tmpDelay = ACCEPT_MIN_SLEEP
|
||||
atomic.AddUint64(&b.cid, 1)
|
||||
go b.handleConnection(CLIENT, conn, b.cid)
|
||||
go b.handleConnection(CLIENT, conn)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) StartListening(typ int) {
|
||||
var hp string
|
||||
if typ == CLIENT {
|
||||
hp = b.config.Host + ":" + b.config.Port
|
||||
log.Info("Start Listening client on ", hp)
|
||||
} else if typ == ROUTER {
|
||||
hp = b.config.Cluster.Host + ":" + b.config.Cluster.Port
|
||||
log.Info("Start Listening cluster on ", hp)
|
||||
func (b *Broker) Handshake(conn net.Conn) bool {
|
||||
|
||||
nc := tls.Server(conn, b.tlsConfig)
|
||||
time.AfterFunc(DEFAULT_TLS_TIMEOUT, func() { TlsTimeout(nc) })
|
||||
nc.SetReadDeadline(time.Now().Add(DEFAULT_TLS_TIMEOUT))
|
||||
|
||||
// Force handshake
|
||||
if err := nc.Handshake(); err != nil {
|
||||
log.Error("TLS handshake error, ", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
nc.SetReadDeadline(time.Time{})
|
||||
return true
|
||||
|
||||
}
|
||||
|
||||
func TlsTimeout(conn *tls.Conn) {
|
||||
nc := conn
|
||||
// Check if already closed
|
||||
if nc == nil {
|
||||
return
|
||||
}
|
||||
cs := nc.ConnectionState()
|
||||
if !cs.HandshakeComplete {
|
||||
log.Error("TLS handshake timeout")
|
||||
nc.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) StartClusterListening() {
|
||||
var hp string = b.config.Cluster.Host + ":" + b.config.Cluster.Port
|
||||
log.Info("Start Listening cluster on ", zap.String("hp", hp))
|
||||
|
||||
l, e := net.Listen("tcp", hp)
|
||||
if e != nil {
|
||||
log.Error("Error listening on ", e)
|
||||
log.Error("Error listening on ", zap.Error(e))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -149,60 +277,95 @@ func (b *Broker) StartListening(typ int) {
|
||||
if err != nil {
|
||||
if ne, ok := err.(net.Error); ok && ne.Temporary() {
|
||||
log.Error("Temporary Client Accept Error(%v), sleeping %dms",
|
||||
ne, tmpDelay/time.Millisecond)
|
||||
zap.Error(ne), zap.Duration("sleeping", tmpDelay/time.Millisecond))
|
||||
time.Sleep(tmpDelay)
|
||||
tmpDelay *= 2
|
||||
if tmpDelay > ACCEPT_MAX_SLEEP {
|
||||
tmpDelay = ACCEPT_MAX_SLEEP
|
||||
}
|
||||
} else {
|
||||
log.Error("Accept error: %v", err)
|
||||
log.Error("Accept error: %v", zap.Error(err))
|
||||
}
|
||||
continue
|
||||
}
|
||||
tmpDelay = ACCEPT_MIN_SLEEP
|
||||
atomic.AddUint64(&b.cid, 1)
|
||||
go b.handleConnection(typ, conn, b.cid)
|
||||
|
||||
go b.handleConnection(ROUTER, conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
|
||||
func (b *Broker) handleConnection(typ int, conn net.Conn) {
|
||||
//process connect packet
|
||||
buf, err := ReadPacket(conn)
|
||||
packet, err := packets.ReadPacket(conn)
|
||||
if err != nil {
|
||||
log.Error("read connect packet error: ", err)
|
||||
log.Error("read connect packet error: ", zap.Error(err))
|
||||
return
|
||||
}
|
||||
connMsg, err := DecodeConnectMessage(buf)
|
||||
if err != nil {
|
||||
log.Error(err)
|
||||
if packet == nil {
|
||||
log.Error("received nil packet")
|
||||
return
|
||||
}
|
||||
msg, ok := packet.(*packets.ConnectPacket)
|
||||
if !ok {
|
||||
log.Error("received msg that was not Connect")
|
||||
return
|
||||
}
|
||||
|
||||
connack := message.NewConnackMessage()
|
||||
connack.SetReturnCode(message.ConnectionAccepted)
|
||||
ack, _ := EncodeMessage(connack)
|
||||
err1 := WriteBuffer(conn, ack)
|
||||
if err1 != nil {
|
||||
log.Error("send connack error, ", err1)
|
||||
log.Info("read connect from ", zap.String("clientID", msg.ClientIdentifier))
|
||||
|
||||
connack := packets.NewControlPacket(packets.Connack).(*packets.ConnackPacket)
|
||||
connack.SessionPresent = msg.CleanSession
|
||||
connack.ReturnCode = msg.Validate()
|
||||
|
||||
if connack.ReturnCode != packets.Accepted {
|
||||
err = connack.Write(conn)
|
||||
if err != nil {
|
||||
log.Error("send connack error, ", zap.Error(err), zap.String("clientID", msg.ClientIdentifier))
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
willmsg := message.NewPublishMessage()
|
||||
if connMsg.WillFlag() {
|
||||
willmsg.SetQoS(connMsg.WillQos())
|
||||
willmsg.SetPayload(connMsg.WillMessage())
|
||||
willmsg.SetRetain(connMsg.WillRetain())
|
||||
willmsg.SetTopic(connMsg.WillTopic())
|
||||
willmsg.SetDup(false)
|
||||
if typ == CLIENT && !b.CheckConnectAuth(string(msg.ClientIdentifier), string(msg.Username), string(msg.Password)) {
|
||||
connack.ReturnCode = packets.ErrRefusedNotAuthorised
|
||||
err = connack.Write(conn)
|
||||
if err != nil {
|
||||
log.Error("send connack error, ", zap.Error(err), zap.String("clientID", msg.ClientIdentifier))
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
err = connack.Write(conn)
|
||||
if err != nil {
|
||||
log.Error("send connack error, ", zap.Error(err), zap.String("clientID", msg.ClientIdentifier))
|
||||
return
|
||||
}
|
||||
|
||||
if typ == CLIENT {
|
||||
b.Publish(&bridge.Elements{
|
||||
ClientID: string(msg.ClientIdentifier),
|
||||
Username: string(msg.Username),
|
||||
Action: bridge.Connect,
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
willmsg := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
|
||||
if msg.WillFlag {
|
||||
willmsg.Qos = msg.WillQos
|
||||
willmsg.TopicName = msg.WillTopic
|
||||
willmsg.Retain = msg.WillRetain
|
||||
willmsg.Payload = msg.WillMessage
|
||||
willmsg.Dup = msg.Dup
|
||||
} else {
|
||||
willmsg = nil
|
||||
}
|
||||
info := info{
|
||||
clientID: connMsg.ClientId(),
|
||||
username: connMsg.Username(),
|
||||
password: connMsg.Password(),
|
||||
keepalive: connMsg.KeepAlive(),
|
||||
clientID: msg.ClientIdentifier,
|
||||
username: msg.Username,
|
||||
password: msg.Password,
|
||||
keepalive: msg.Keepalive,
|
||||
willMsg: willmsg,
|
||||
}
|
||||
|
||||
@@ -212,113 +375,257 @@ func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) {
|
||||
conn: conn,
|
||||
info: info,
|
||||
}
|
||||
|
||||
c.init()
|
||||
|
||||
var msgPool *MessagePool
|
||||
err = b.getSession(c, msg, connack)
|
||||
if err != nil {
|
||||
log.Error("get session error: ", zap.String("clientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
|
||||
cid := c.info.clientID
|
||||
|
||||
var exist bool
|
||||
var old *client
|
||||
cid := string(c.info.clientID)
|
||||
if typ == CLIENT {
|
||||
old, exist = b.clients.Update(cid, c)
|
||||
msgPool = MSGPool[idx%MessagePoolNum].GetPool()
|
||||
} else if typ == ROUTER {
|
||||
old, exist = b.routes.Update(cid, c)
|
||||
msgPool = MSGPool[MessagePoolNum].GetPool()
|
||||
}
|
||||
if exist {
|
||||
log.Warn("client or routers exists, close old...")
|
||||
old.Close()
|
||||
}
|
||||
c.readLoop(msgPool)
|
||||
}
|
||||
var old interface{}
|
||||
|
||||
func (b *Broker) ConnectToRouters() {
|
||||
for i := 0; i < len(b.config.Cluster.Routes); i++ {
|
||||
url := b.config.Cluster.Routes[i]
|
||||
go b.connectRouter(url, "")
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) connectRouter(url, remoteID string) {
|
||||
for {
|
||||
conn, err := net.Dial("tcp", url)
|
||||
if err != nil {
|
||||
log.Error("Error trying to connect to route: ", err)
|
||||
select {
|
||||
case <-time.After(DEFAULT_ROUTE_CONNECT):
|
||||
log.Debug("Connect to route timeout ,retry...")
|
||||
continue
|
||||
switch typ {
|
||||
case CLIENT:
|
||||
old, exist = b.clients.Load(cid)
|
||||
if exist {
|
||||
log.Warn("client exist, close old...", zap.String("clientID", c.info.clientID))
|
||||
ol, ok := old.(*client)
|
||||
if ok {
|
||||
ol.Close()
|
||||
}
|
||||
}
|
||||
route := &route{
|
||||
remoteID: remoteID,
|
||||
remoteUrl: url,
|
||||
b.clients.Store(cid, c)
|
||||
|
||||
b.OnlineOfflineNotification(cid, true)
|
||||
case ROUTER:
|
||||
old, exist = b.routes.Load(cid)
|
||||
if exist {
|
||||
log.Warn("router exist, close old...")
|
||||
ol, ok := old.(*client)
|
||||
if ok {
|
||||
ol.Close()
|
||||
}
|
||||
}
|
||||
cid := GenUniqueId()
|
||||
info := info{
|
||||
clientID: []byte(cid),
|
||||
}
|
||||
c := &client{
|
||||
typ: REMOTE,
|
||||
conn: conn,
|
||||
route: route,
|
||||
info: info,
|
||||
}
|
||||
b.remotes.Set(cid, c)
|
||||
c.SendConnect()
|
||||
c.SendInfo()
|
||||
// s.createRemote(conn, route)
|
||||
msgPool := MSGPool[(MessagePoolNum + 1)].GetPool()
|
||||
c.readLoop(msgPool)
|
||||
b.routes.Store(cid, c)
|
||||
}
|
||||
|
||||
c.readLoop()
|
||||
}
|
||||
|
||||
func (b *Broker) ConnectToDiscovery() {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
var tempDelay time.Duration = 0
|
||||
for {
|
||||
conn, err = net.Dial("tcp", b.config.Router)
|
||||
if err != nil {
|
||||
log.Error("Error trying to connect to route: ", zap.Error(err))
|
||||
log.Debug("Connect to route timeout ,retry...")
|
||||
|
||||
if 0 == tempDelay {
|
||||
tempDelay = 1 * time.Second
|
||||
} else {
|
||||
tempDelay *= 2
|
||||
}
|
||||
|
||||
if max := 20 * time.Second; tempDelay > max {
|
||||
tempDelay = max
|
||||
}
|
||||
time.Sleep(tempDelay)
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
log.Debug("connect to router success :", zap.String("Router", b.config.Router))
|
||||
|
||||
cid := b.id
|
||||
info := info{
|
||||
clientID: cid,
|
||||
keepalive: 60,
|
||||
}
|
||||
|
||||
c := &client{
|
||||
typ: CLUSTER,
|
||||
broker: b,
|
||||
conn: conn,
|
||||
info: info,
|
||||
}
|
||||
|
||||
c.init()
|
||||
|
||||
c.SendConnect()
|
||||
c.SendInfo()
|
||||
|
||||
go c.readLoop()
|
||||
go c.StartPing()
|
||||
}
|
||||
|
||||
func (b *Broker) processClusterInfo() {
|
||||
for {
|
||||
msg, ok := <-b.clusterPool
|
||||
if !ok {
|
||||
log.Error("read message from cluster channel error")
|
||||
return
|
||||
}
|
||||
ProcessMessage(msg)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (b *Broker) connectRouter(id, addr string) {
|
||||
var conn net.Conn
|
||||
var err error
|
||||
var timeDelay time.Duration = 0
|
||||
retryTimes := 0
|
||||
max := 32 * time.Second
|
||||
for {
|
||||
|
||||
if !b.checkNodeExist(id, addr) {
|
||||
return
|
||||
}
|
||||
|
||||
conn, err = net.Dial("tcp", addr)
|
||||
if err != nil {
|
||||
log.Error("Error trying to connect to route: ", zap.Error(err))
|
||||
|
||||
if retryTimes > 50 {
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug("Connect to route timeout ,retry...")
|
||||
|
||||
if 0 == timeDelay {
|
||||
timeDelay = 1 * time.Second
|
||||
} else {
|
||||
timeDelay *= 2
|
||||
}
|
||||
|
||||
if timeDelay > max {
|
||||
timeDelay = max
|
||||
}
|
||||
time.Sleep(timeDelay)
|
||||
retryTimes++
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
route := route{
|
||||
remoteID: id,
|
||||
remoteUrl: addr,
|
||||
}
|
||||
cid := GenUniqueId()
|
||||
|
||||
info := info{
|
||||
clientID: cid,
|
||||
keepalive: 60,
|
||||
}
|
||||
|
||||
c := &client{
|
||||
broker: b,
|
||||
typ: REMOTE,
|
||||
conn: conn,
|
||||
route: route,
|
||||
info: info,
|
||||
}
|
||||
c.init()
|
||||
b.remotes.Store(cid, c)
|
||||
|
||||
c.SendConnect()
|
||||
|
||||
// mpool := b.messagePool[fnv1a.HashString64(cid)%MessagePoolNum]
|
||||
go c.readLoop()
|
||||
go c.StartPing()
|
||||
|
||||
}
|
||||
|
||||
func (b *Broker) checkNodeExist(id, url string) bool {
|
||||
if id == b.id {
|
||||
return false
|
||||
}
|
||||
|
||||
for k, v := range b.nodes {
|
||||
if k == id {
|
||||
return true
|
||||
}
|
||||
|
||||
//skip
|
||||
l, ok := v.(string)
|
||||
if ok {
|
||||
if url == l {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (b *Broker) CheckRemoteExist(remoteID, url string) bool {
|
||||
exist := false
|
||||
remotes := b.remotes.Items()
|
||||
for _, v := range remotes {
|
||||
if v.route.remoteUrl == url {
|
||||
// if v.route.remoteID == "" || v.route.remoteID != remoteID {
|
||||
v.route.remoteID = remoteID
|
||||
// }
|
||||
exist = true
|
||||
break
|
||||
b.remotes.Range(func(key, value interface{}) bool {
|
||||
v, ok := value.(*client)
|
||||
if ok {
|
||||
if v.route.remoteUrl == url {
|
||||
v.route.remoteID = remoteID
|
||||
exist = true
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return exist
|
||||
}
|
||||
|
||||
func (b *Broker) SendLocalSubsToRouter(c *client) {
|
||||
clients := b.clients.Items()
|
||||
subMsg := message.NewSubscribeMessage()
|
||||
for _, client := range clients {
|
||||
subs := client.subs
|
||||
for _, sub := range subs {
|
||||
subMsg.AddTopic(sub.topic, sub.qos)
|
||||
subInfo := packets.NewControlPacket(packets.Subscribe).(*packets.SubscribePacket)
|
||||
b.clients.Range(func(key, value interface{}) bool {
|
||||
client, ok := value.(*client)
|
||||
if ok {
|
||||
subs := client.subMap
|
||||
for _, sub := range subs {
|
||||
subInfo.Topics = append(subInfo.Topics, sub.topic)
|
||||
subInfo.Qoss = append(subInfo.Qoss, sub.qos)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
if len(subInfo.Topics) > 0 {
|
||||
err := c.WriterPacket(subInfo)
|
||||
if err != nil {
|
||||
log.Error("Send localsubs To Router error :", zap.Error(err))
|
||||
}
|
||||
}
|
||||
err := c.writeMessage(subMsg)
|
||||
if err != nil {
|
||||
log.Error("Send localsubs To Router error :", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) BroadcastInfoMessage(remoteID string, msg message.Message) {
|
||||
remotes := b.remotes.Items()
|
||||
for _, r := range remotes {
|
||||
if r.route.remoteID == remoteID {
|
||||
continue
|
||||
func (b *Broker) BroadcastInfoMessage(remoteID string, msg *packets.PublishPacket) {
|
||||
b.routes.Range(func(key, value interface{}) bool {
|
||||
r, ok := value.(*client)
|
||||
if ok {
|
||||
if r.route.remoteID == remoteID {
|
||||
return true
|
||||
}
|
||||
r.WriterPacket(msg)
|
||||
}
|
||||
r.writeMessage(msg)
|
||||
}
|
||||
return true
|
||||
|
||||
})
|
||||
// log.Info("BroadcastInfoMessage success ")
|
||||
}
|
||||
|
||||
func (b *Broker) BroadcastSubOrUnsubMessage(buf []byte) {
|
||||
remotes := b.remotes.Items()
|
||||
for _, r := range remotes {
|
||||
r.writeBuffer(buf)
|
||||
}
|
||||
func (b *Broker) BroadcastSubOrUnsubMessage(packet packets.ControlPacket) {
|
||||
|
||||
b.routes.Range(func(key, value interface{}) bool {
|
||||
r, ok := value.(*client)
|
||||
if ok {
|
||||
r.WriterPacket(packet)
|
||||
}
|
||||
return true
|
||||
})
|
||||
// log.Info("BroadcastSubscribeMessage remotes: ", s.remotes)
|
||||
}
|
||||
|
||||
@@ -327,48 +634,54 @@ func (b *Broker) removeClient(c *client) {
|
||||
typ := c.typ
|
||||
switch typ {
|
||||
case CLIENT:
|
||||
b.clients.Remove(clientId)
|
||||
b.clients.Delete(clientId)
|
||||
case ROUTER:
|
||||
b.routes.Remove(clientId)
|
||||
b.routes.Delete(clientId)
|
||||
case REMOTE:
|
||||
b.remotes.Remove(clientId)
|
||||
b.remotes.Delete(clientId)
|
||||
}
|
||||
// log.Info("delete client ,", clientId)
|
||||
}
|
||||
|
||||
func (b *Broker) ProcessPublishMessage(msg *message.PublishMessage) {
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
topic := string(msg.Topic())
|
||||
|
||||
r := b.sl.Match(topic)
|
||||
// log.Info("psubs num: ", len(r.psubs))
|
||||
if len(r.qsubs) == 0 && len(r.psubs) == 0 {
|
||||
func (b *Broker) PublishMessage(packet *packets.PublishPacket) {
|
||||
var subs []interface{}
|
||||
var qoss []byte
|
||||
b.mu.Lock()
|
||||
err := b.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &subs, &qoss)
|
||||
b.mu.Unlock()
|
||||
if err != nil {
|
||||
log.Error("search sub client error, ", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
for _, sub := range r.psubs {
|
||||
if sub != nil {
|
||||
err := sub.client.writeMessage(msg)
|
||||
for _, sub := range subs {
|
||||
s, ok := sub.(*subscription)
|
||||
if ok {
|
||||
err := s.client.WriterPacket(packet)
|
||||
if err != nil {
|
||||
log.Error("process message for psub error, ", err)
|
||||
log.Error("write message error, ", zap.Error(err))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for i, sub := range r.qsubs {
|
||||
// s.qmu.Lock()
|
||||
if cnt, exist := b.queues[string(sub.topic)]; exist && i == cnt {
|
||||
if sub != nil {
|
||||
err := sub.client.writeMessage(msg)
|
||||
if err != nil {
|
||||
log.Error("process will message for qsub error, ", err)
|
||||
}
|
||||
}
|
||||
b.queues[topic] = (b.queues[topic] + 1) % len(r.qsubs)
|
||||
break
|
||||
}
|
||||
// s.qmu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) BroadcastUnSubscribe(subs map[string]*subscription) {
|
||||
|
||||
unsub := packets.NewControlPacket(packets.Unsubscribe).(*packets.UnsubscribePacket)
|
||||
for topic, _ := range subs {
|
||||
unsub.Topics = append(unsub.Topics, topic)
|
||||
}
|
||||
|
||||
if len(unsub.Topics) > 0 {
|
||||
b.BroadcastSubOrUnsubMessage(unsub)
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Broker) OnlineOfflineNotification(clientID string, online bool) {
|
||||
packet := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
|
||||
packet.TopicName = "$SYS/broker/connection/clients/" + clientID
|
||||
packet.Qos = 0
|
||||
packet.Payload = []byte(fmt.Sprintf(`{"clientID":"%s","online":%v,"timestamp":"%s"}`, clientID, online, time.Now().UTC().Format(time.RFC3339)))
|
||||
|
||||
b.PublishMessage(packet)
|
||||
}
|
||||
|
||||
814
broker/client.go
814
broker/client.go
@@ -1,49 +1,84 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package broker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"hmq/lib/message"
|
||||
"math/rand"
|
||||
"net"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/cihub/seelog"
|
||||
"github.com/fhmq/hmq/broker/lib/sessions"
|
||||
"github.com/fhmq/hmq/broker/lib/topics"
|
||||
"github.com/fhmq/hmq/plugins/bridge"
|
||||
|
||||
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
const (
|
||||
// special pub topic for cluster info BrokerInfoTopic
|
||||
BrokerInfoTopic = "broker001info/brokerinfo"
|
||||
BrokerInfoTopic = "broker000100101info"
|
||||
// CLIENT is an end user.
|
||||
CLIENT = 0
|
||||
// ROUTER is another router in the cluster.
|
||||
ROUTER = 1
|
||||
//REMOTE is the router connect to other cluster
|
||||
REMOTE = 2
|
||||
REMOTE = 2
|
||||
CLUSTER = 3
|
||||
)
|
||||
|
||||
const (
|
||||
_GroupTopicRegexp = `^\$share/([0-9a-zA-Z_-]+)/(.*)$`
|
||||
)
|
||||
|
||||
const (
|
||||
Connected = 1
|
||||
Disconnected = 2
|
||||
)
|
||||
|
||||
var (
|
||||
groupCompile = regexp.MustCompile(_GroupTopicRegexp)
|
||||
)
|
||||
|
||||
type client struct {
|
||||
typ int
|
||||
mu sync.Mutex
|
||||
broker *Broker
|
||||
conn net.Conn
|
||||
info info
|
||||
route *route
|
||||
subs map[string]*subscription
|
||||
typ int
|
||||
mu sync.Mutex
|
||||
broker *Broker
|
||||
conn net.Conn
|
||||
info info
|
||||
route route
|
||||
status int
|
||||
ctx context.Context
|
||||
cancelFunc context.CancelFunc
|
||||
session *sessions.Session
|
||||
subMap map[string]*subscription
|
||||
topicsMgr *topics.Manager
|
||||
subs []interface{}
|
||||
qoss []byte
|
||||
rmsgs []*packets.PublishPacket
|
||||
routeSubMap map[string]uint64
|
||||
}
|
||||
|
||||
type subscription struct {
|
||||
client *client
|
||||
topic []byte
|
||||
qos byte
|
||||
queue bool
|
||||
client *client
|
||||
topic string
|
||||
qos byte
|
||||
share bool
|
||||
groupName string
|
||||
}
|
||||
|
||||
type info struct {
|
||||
clientID []byte
|
||||
username []byte
|
||||
clientID string
|
||||
username string
|
||||
password []byte
|
||||
keepalive uint16
|
||||
willMsg *message.PublishMessage
|
||||
willMsg *packets.PublishPacket
|
||||
localIP string
|
||||
remoteIP string
|
||||
}
|
||||
@@ -53,386 +88,577 @@ type route struct {
|
||||
remoteUrl string
|
||||
}
|
||||
|
||||
var (
|
||||
DisconnectdPacket = packets.NewControlPacket(packets.Disconnect).(*packets.DisconnectPacket)
|
||||
r = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
)
|
||||
|
||||
func (c *client) init() {
|
||||
c.subs = make(map[string]*subscription, 10)
|
||||
c.status = Connected
|
||||
c.info.localIP = strings.Split(c.conn.LocalAddr().String(), ":")[0]
|
||||
c.info.remoteIP = strings.Split(c.conn.RemoteAddr().String(), ":")[0]
|
||||
c.ctx, c.cancelFunc = context.WithCancel(context.Background())
|
||||
c.subMap = make(map[string]*subscription)
|
||||
c.topicsMgr = c.broker.topicsMgr
|
||||
}
|
||||
|
||||
func (c *client) readLoop(msgPool *MessagePool) {
|
||||
func (c *client) readLoop() {
|
||||
nc := c.conn
|
||||
if nc == nil || msgPool == nil {
|
||||
b := c.broker
|
||||
if nc == nil || b == nil {
|
||||
return
|
||||
}
|
||||
msg := &Message{}
|
||||
|
||||
keepAlive := time.Second * time.Duration(c.info.keepalive)
|
||||
timeOut := keepAlive + (keepAlive / 2)
|
||||
|
||||
for {
|
||||
buf, err := ReadPacket(nc)
|
||||
if err != nil {
|
||||
log.Error("read packet error: ", err)
|
||||
c.Close()
|
||||
select {
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
default:
|
||||
//add read timeout
|
||||
if err := nc.SetReadDeadline(time.Now().Add(timeOut)); err != nil {
|
||||
log.Error("set read timeout error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
msg := &Message{
|
||||
client: c,
|
||||
packet: DisconnectdPacket,
|
||||
}
|
||||
b.SubmitWork(c.info.clientID, msg)
|
||||
return
|
||||
}
|
||||
|
||||
packet, err := packets.ReadPacket(nc)
|
||||
if err != nil {
|
||||
log.Error("read packet error: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
msg := &Message{
|
||||
client: c,
|
||||
packet: DisconnectdPacket,
|
||||
}
|
||||
b.SubmitWork(c.info.clientID, msg)
|
||||
return
|
||||
}
|
||||
|
||||
msg := &Message{
|
||||
client: c,
|
||||
packet: packet,
|
||||
}
|
||||
b.SubmitWork(c.info.clientID, msg)
|
||||
}
|
||||
msg.client = c
|
||||
msg.msg = buf
|
||||
msgPool.queue <- msg
|
||||
}
|
||||
msgPool.Reduce()
|
||||
|
||||
}
|
||||
|
||||
func ProcessMessage(msg *Message) {
|
||||
buf := msg.msg
|
||||
c := msg.client
|
||||
if c == nil || buf == nil {
|
||||
ca := msg.packet
|
||||
if ca == nil {
|
||||
return
|
||||
}
|
||||
msgType := uint8(buf[0] & 0xF0 >> 4)
|
||||
switch msgType {
|
||||
case CONNACK:
|
||||
// log.Info("Recv conack message..........")
|
||||
c.ProcessConnAck(buf)
|
||||
case CONNECT:
|
||||
// log.Info("Recv connect message..........")
|
||||
c.ProcessConnect(buf)
|
||||
case PUBLISH:
|
||||
// log.Info("Recv publish message..........")
|
||||
c.ProcessPublish(buf)
|
||||
case PUBACK:
|
||||
//log.Info("Recv publish ack message..........")
|
||||
c.ProcessPubAck(buf)
|
||||
case PUBCOMP:
|
||||
//log.Info("Recv publish ack message..........")
|
||||
c.ProcessPubComp(buf)
|
||||
case PUBREC:
|
||||
//log.Info("Recv publish rec message..........")
|
||||
c.ProcessPubREC(buf)
|
||||
case PUBREL:
|
||||
//log.Info("Recv publish rel message..........")
|
||||
c.ProcessPubREL(buf)
|
||||
case SUBSCRIBE:
|
||||
// log.Info("Recv subscribe message.....")
|
||||
c.ProcessSubscribe(buf)
|
||||
case SUBACK:
|
||||
// log.Info("Recv suback message.....")
|
||||
case UNSUBSCRIBE:
|
||||
// log.Info("Recv unsubscribe message.....")
|
||||
c.ProcessUnSubscribe(buf)
|
||||
case UNSUBACK:
|
||||
//log.Info("Recv unsuback message.....")
|
||||
case PINGREQ:
|
||||
// log.Info("Recv PINGREQ message..........")
|
||||
c.ProcessPing(buf)
|
||||
case PINGRESP:
|
||||
//log.Info("Recv PINGRESP message..........")
|
||||
case DISCONNECT:
|
||||
// log.Info("Recv DISCONNECT message.......")
|
||||
|
||||
if c.typ == CLIENT {
|
||||
log.Debug("Recv message:", zap.String("message type", reflect.TypeOf(msg.packet).String()[9:]), zap.String("ClientID", c.info.clientID))
|
||||
}
|
||||
|
||||
switch ca.(type) {
|
||||
case *packets.ConnackPacket:
|
||||
case *packets.ConnectPacket:
|
||||
case *packets.PublishPacket:
|
||||
packet := ca.(*packets.PublishPacket)
|
||||
c.ProcessPublish(packet)
|
||||
case *packets.PubackPacket:
|
||||
case *packets.PubrecPacket:
|
||||
case *packets.PubrelPacket:
|
||||
case *packets.PubcompPacket:
|
||||
case *packets.SubscribePacket:
|
||||
packet := ca.(*packets.SubscribePacket)
|
||||
c.ProcessSubscribe(packet)
|
||||
case *packets.SubackPacket:
|
||||
case *packets.UnsubscribePacket:
|
||||
packet := ca.(*packets.UnsubscribePacket)
|
||||
c.ProcessUnSubscribe(packet)
|
||||
case *packets.UnsubackPacket:
|
||||
case *packets.PingreqPacket:
|
||||
c.ProcessPing()
|
||||
case *packets.PingrespPacket:
|
||||
case *packets.DisconnectPacket:
|
||||
c.Close()
|
||||
default:
|
||||
log.Info("Recv Unknow message.......")
|
||||
log.Info("Recv Unknow message.......", zap.String("ClientID", c.info.clientID))
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) ProcessConnect(buf []byte) {
|
||||
func (c *client) ProcessPublish(packet *packets.PublishPacket) {
|
||||
switch c.typ {
|
||||
case CLIENT:
|
||||
c.processClientPublish(packet)
|
||||
case ROUTER:
|
||||
c.processRouterPublish(packet)
|
||||
case CLUSTER:
|
||||
c.processRemotePublish(packet)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *client) ProcessConnAck(buf []byte) {
|
||||
|
||||
}
|
||||
|
||||
func (c *client) ProcessPublish(buf []byte) {
|
||||
msg, err := DecodePublishMessage(buf)
|
||||
if err != nil {
|
||||
log.Error("Decode Publish Message error: ", err)
|
||||
c.Close()
|
||||
func (c *client) processRemotePublish(packet *packets.PublishPacket) {
|
||||
if c.status == Disconnected {
|
||||
return
|
||||
}
|
||||
topic := msg.Topic()
|
||||
|
||||
if c.typ != CLIENT || !c.CheckTopicAuth(PUB, string(topic)) {
|
||||
topic := packet.TopicName
|
||||
if topic == BrokerInfoTopic {
|
||||
c.ProcessInfo(packet)
|
||||
return
|
||||
}
|
||||
c.ProcessPublishMessage(buf, msg)
|
||||
|
||||
if msg.Retain() {
|
||||
if b := c.broker; b != nil {
|
||||
err := b.rl.Insert(topic, buf)
|
||||
if err != nil {
|
||||
log.Error("Insert Retain Message error: ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) processRouterPublish(packet *packets.PublishPacket) {
|
||||
if c.status == Disconnected {
|
||||
return
|
||||
}
|
||||
|
||||
switch packet.Qos {
|
||||
case QosAtMostOnce:
|
||||
c.ProcessPublishMessage(packet)
|
||||
case QosAtLeastOnce:
|
||||
puback := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
|
||||
puback.MessageID = packet.MessageID
|
||||
if err := c.WriterPacket(puback); err != nil {
|
||||
log.Error("send puback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
c.ProcessPublishMessage(packet)
|
||||
case QosExactlyOnce:
|
||||
return
|
||||
default:
|
||||
log.Error("publish with unknown qos", zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *client) ProcessPublishMessage(buf []byte, msg *message.PublishMessage) {
|
||||
func (c *client) processClientPublish(packet *packets.PublishPacket) {
|
||||
|
||||
topic := packet.TopicName
|
||||
|
||||
if !c.broker.CheckTopicAuth(PUB, c.info.username, topic) {
|
||||
log.Error("Pub Topics Auth failed, ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
|
||||
//publish kafka
|
||||
c.broker.Publish(&bridge.Elements{
|
||||
ClientID: c.info.clientID,
|
||||
Username: c.info.username,
|
||||
Action: bridge.Publish,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Payload: string(packet.Payload),
|
||||
Topic: topic,
|
||||
})
|
||||
|
||||
switch packet.Qos {
|
||||
case QosAtMostOnce:
|
||||
c.ProcessPublishMessage(packet)
|
||||
case QosAtLeastOnce:
|
||||
puback := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
|
||||
puback.MessageID = packet.MessageID
|
||||
if err := c.WriterPacket(puback); err != nil {
|
||||
log.Error("send puback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
c.ProcessPublishMessage(packet)
|
||||
case QosExactlyOnce:
|
||||
return
|
||||
default:
|
||||
log.Error("publish with unknown qos", zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (c *client) ProcessPublishMessage(packet *packets.PublishPacket) {
|
||||
|
||||
b := c.broker
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
typ := c.typ
|
||||
topic := string(msg.Topic())
|
||||
|
||||
r := b.sl.Match(topic)
|
||||
// log.Info("psubs num: ", len(r.psubs))
|
||||
if len(r.qsubs) == 0 && len(r.psubs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for _, sub := range r.psubs {
|
||||
if sub.client.typ == ROUTER {
|
||||
if typ == ROUTER {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if sub != nil {
|
||||
err := sub.client.writeBuffer(buf)
|
||||
if err != nil {
|
||||
log.Error("process message for psub error, ", err)
|
||||
}
|
||||
if packet.Retain {
|
||||
if err := c.topicsMgr.Retain(packet); err != nil {
|
||||
log.Error("Error retaining message: ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
}
|
||||
}
|
||||
|
||||
for i, sub := range r.qsubs {
|
||||
if sub.client.typ == ROUTER {
|
||||
if typ == ROUTER {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// s.qmu.Lock()
|
||||
if cnt, exist := b.queues[string(sub.topic)]; exist && i == cnt {
|
||||
if sub != nil {
|
||||
err := sub.client.writeBuffer(buf)
|
||||
if err != nil {
|
||||
log.Error("process will message for qsub error, ", err)
|
||||
}
|
||||
}
|
||||
b.queues[topic] = (b.queues[topic] + 1) % len(r.qsubs)
|
||||
break
|
||||
}
|
||||
// s.qmu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) ProcessPubAck(buf []byte) {
|
||||
|
||||
}
|
||||
|
||||
func (c *client) ProcessPubREC(buf []byte) {
|
||||
|
||||
}
|
||||
|
||||
func (c *client) ProcessPubREL(buf []byte) {
|
||||
|
||||
}
|
||||
|
||||
func (c *client) ProcessPubComp(buf []byte) {
|
||||
|
||||
}
|
||||
|
||||
func (c *client) ProcessSubscribe(buf []byte) {
|
||||
b := c.broker
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
msg, err := DecodeSubscribeMessage(buf)
|
||||
err := c.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &c.subs, &c.qoss)
|
||||
if err != nil {
|
||||
log.Error("Decode Subscribe Message error: ", err)
|
||||
c.Close()
|
||||
log.Error("Error retrieving subscribers list: ", zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
topics := msg.Topics()
|
||||
qos := msg.Qos()
|
||||
|
||||
suback := message.NewSubackMessage()
|
||||
suback.SetPacketId(msg.PacketId())
|
||||
var retcodes []byte
|
||||
// fmt.Println("psubs num: ", len(c.subs))
|
||||
if len(c.subs) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
for i, t := range topics {
|
||||
topic := string(t)
|
||||
//check topic auth for client
|
||||
if c.typ == CLIENT {
|
||||
if !c.CheckTopicAuth(SUB, topic) {
|
||||
log.Error("CheckSubAuth failed")
|
||||
retcodes = append(retcodes, message.QosFailure)
|
||||
continue
|
||||
}
|
||||
}
|
||||
if _, exist := c.subs[topic]; !exist {
|
||||
queue := false
|
||||
if strings.HasPrefix(topic, "$queue/") {
|
||||
if len(t) > 7 {
|
||||
t = t[7:]
|
||||
queue = true
|
||||
// b.qmu.Lock()
|
||||
if _, exists := b.queues[topic]; !exists {
|
||||
b.queues[topic] = 0
|
||||
}
|
||||
// b.qmu.Unlock()
|
||||
} else {
|
||||
retcodes = append(retcodes, message.QosFailure)
|
||||
var qsub []int
|
||||
for i, sub := range c.subs {
|
||||
s, ok := sub.(*subscription)
|
||||
if ok {
|
||||
if s.client.typ == ROUTER {
|
||||
if typ != CLIENT {
|
||||
continue
|
||||
}
|
||||
}
|
||||
sub := &subscription{
|
||||
topic: t,
|
||||
qos: qos[i],
|
||||
client: c,
|
||||
queue: queue,
|
||||
if s.share {
|
||||
qsub = append(qsub, i)
|
||||
} else {
|
||||
publish(s, packet)
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
c.subs[topic] = sub
|
||||
c.mu.Unlock()
|
||||
|
||||
err := b.sl.Insert(sub)
|
||||
if err != nil {
|
||||
log.Error("Insert subscription error: ", err)
|
||||
retcodes = append(retcodes, message.QosFailure)
|
||||
}
|
||||
retcodes = append(retcodes, qos[i])
|
||||
} else {
|
||||
//if exist ,check whether qos change
|
||||
c.subs[topic].qos = qos[i]
|
||||
retcodes = append(retcodes, qos[i])
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
if err := suback.AddReturnCodes(retcodes); err != nil {
|
||||
log.Error("add return suback code error, ", err)
|
||||
// if typ == CLIENT {
|
||||
c.Close()
|
||||
// }
|
||||
return
|
||||
if len(qsub) > 0 {
|
||||
idx := r.Intn(len(qsub))
|
||||
sub := c.subs[qsub[idx]].(*subscription)
|
||||
publish(sub, packet)
|
||||
}
|
||||
|
||||
err1 := c.writeMessage(suback)
|
||||
if err1 != nil {
|
||||
log.Error("send suback error, ", err1)
|
||||
return
|
||||
}
|
||||
//broadcast subscribe message
|
||||
if c.typ == CLIENT {
|
||||
go b.BroadcastSubOrUnsubMessage(buf)
|
||||
}
|
||||
}
|
||||
|
||||
//process retain message
|
||||
for _, t := range topics {
|
||||
bufs := b.rl.Match(t)
|
||||
for _, buf := range bufs {
|
||||
log.Info("process retain message: ", string(buf))
|
||||
if buf != nil && string(buf) != "" {
|
||||
c.writeBuffer(buf)
|
||||
}
|
||||
}
|
||||
func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) {
|
||||
switch c.typ {
|
||||
case CLIENT:
|
||||
c.processClientSubscribe(packet)
|
||||
case ROUTER:
|
||||
c.processRouterSubscribe(packet)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) ProcessUnSubscribe(buf []byte) {
|
||||
func (c *client) processClientSubscribe(packet *packets.SubscribePacket) {
|
||||
if c.status == Disconnected {
|
||||
return
|
||||
}
|
||||
|
||||
b := c.broker
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
topics := packet.Topics
|
||||
qoss := packet.Qoss
|
||||
|
||||
unsub, err := DecodeUnsubscribeMessage(buf)
|
||||
suback := packets.NewControlPacket(packets.Suback).(*packets.SubackPacket)
|
||||
suback.MessageID = packet.MessageID
|
||||
var retcodes []byte
|
||||
|
||||
for i, topic := range topics {
|
||||
t := topic
|
||||
//check topic auth for client
|
||||
if !b.CheckTopicAuth(SUB, c.info.username, topic) {
|
||||
log.Error("Sub topic Auth failed: ", zap.String("topic", topic), zap.String("ClientID", c.info.clientID))
|
||||
retcodes = append(retcodes, QosFailure)
|
||||
continue
|
||||
}
|
||||
|
||||
b.Publish(&bridge.Elements{
|
||||
ClientID: c.info.clientID,
|
||||
Username: c.info.username,
|
||||
Action: bridge.Subscribe,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Topic: topic,
|
||||
})
|
||||
|
||||
groupName := ""
|
||||
share := false
|
||||
if strings.HasPrefix(topic, "$share/") {
|
||||
substr := groupCompile.FindStringSubmatch(topic)
|
||||
if len(substr) != 3 {
|
||||
retcodes = append(retcodes, QosFailure)
|
||||
continue
|
||||
}
|
||||
share = true
|
||||
groupName = substr[1]
|
||||
topic = substr[2]
|
||||
}
|
||||
|
||||
sub := &subscription{
|
||||
topic: topic,
|
||||
qos: qoss[i],
|
||||
client: c,
|
||||
share: share,
|
||||
groupName: groupName,
|
||||
}
|
||||
|
||||
rqos, err := c.topicsMgr.Subscribe([]byte(topic), qoss[i], sub)
|
||||
if err != nil {
|
||||
log.Error("subscribe error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
retcodes = append(retcodes, QosFailure)
|
||||
continue
|
||||
}
|
||||
|
||||
c.subMap[t] = sub
|
||||
|
||||
c.session.AddTopic(t, qoss[i])
|
||||
retcodes = append(retcodes, rqos)
|
||||
c.topicsMgr.Retained([]byte(topic), &c.rmsgs)
|
||||
|
||||
}
|
||||
|
||||
suback.ReturnCodes = retcodes
|
||||
|
||||
err := c.WriterPacket(suback)
|
||||
if err != nil {
|
||||
log.Error("Decode UnSubscribe Message error: ", err)
|
||||
c.Close()
|
||||
log.Error("send suback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
topics := unsub.Topics()
|
||||
//broadcast subscribe message
|
||||
go b.BroadcastSubOrUnsubMessage(packet)
|
||||
|
||||
for _, t := range topics {
|
||||
var sub *subscription
|
||||
ok := false
|
||||
//process retain message
|
||||
for _, rm := range c.rmsgs {
|
||||
if err := c.WriterPacket(rm); err != nil {
|
||||
log.Error("Error publishing retained message:", zap.Any("err", err), zap.String("ClientID", c.info.clientID))
|
||||
} else {
|
||||
log.Info("process retain message: ", zap.Any("packet", packet), zap.String("ClientID", c.info.clientID))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if sub, ok = c.subs[string(t)]; ok {
|
||||
go c.unsubscribe(sub)
|
||||
func (c *client) processRouterSubscribe(packet *packets.SubscribePacket) {
|
||||
if c.status == Disconnected {
|
||||
return
|
||||
}
|
||||
|
||||
b := c.broker
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
topics := packet.Topics
|
||||
qoss := packet.Qoss
|
||||
|
||||
suback := packets.NewControlPacket(packets.Suback).(*packets.SubackPacket)
|
||||
suback.MessageID = packet.MessageID
|
||||
var retcodes []byte
|
||||
|
||||
for i, topic := range topics {
|
||||
t := topic
|
||||
groupName := ""
|
||||
share := false
|
||||
if strings.HasPrefix(topic, "$share/") {
|
||||
substr := groupCompile.FindStringSubmatch(topic)
|
||||
if len(substr) != 3 {
|
||||
retcodes = append(retcodes, QosFailure)
|
||||
continue
|
||||
}
|
||||
share = true
|
||||
groupName = substr[1]
|
||||
topic = substr[2]
|
||||
}
|
||||
|
||||
sub := &subscription{
|
||||
topic: topic,
|
||||
qos: qoss[i],
|
||||
client: c,
|
||||
share: share,
|
||||
groupName: groupName,
|
||||
}
|
||||
|
||||
rqos, err := c.topicsMgr.Subscribe([]byte(topic), qoss[i], sub)
|
||||
if err != nil {
|
||||
log.Error("subscribe error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
retcodes = append(retcodes, QosFailure)
|
||||
continue
|
||||
}
|
||||
|
||||
c.subMap[t] = sub
|
||||
addSubMap(c.routeSubMap, topic)
|
||||
retcodes = append(retcodes, rqos)
|
||||
}
|
||||
|
||||
suback.ReturnCodes = retcodes
|
||||
|
||||
err := c.WriterPacket(suback)
|
||||
if err != nil {
|
||||
log.Error("send suback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) {
|
||||
switch c.typ {
|
||||
case CLIENT:
|
||||
c.processClientUnSubscribe(packet)
|
||||
case ROUTER:
|
||||
c.processRouterUnSubscribe(packet)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) processRouterUnSubscribe(packet *packets.UnsubscribePacket) {
|
||||
if c.status == Disconnected {
|
||||
return
|
||||
}
|
||||
b := c.broker
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
topics := packet.Topics
|
||||
|
||||
for _, topic := range topics {
|
||||
sub, exist := c.subMap[topic]
|
||||
if exist {
|
||||
retainNum := delSubMap(c.routeSubMap, topic)
|
||||
if retainNum > 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
c.topicsMgr.Unsubscribe([]byte(sub.topic), sub)
|
||||
delete(c.subMap, topic)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
resp := message.NewUnsubackMessage()
|
||||
resp.SetPacketId(unsub.PacketId())
|
||||
unsuback := packets.NewControlPacket(packets.Unsuback).(*packets.UnsubackPacket)
|
||||
unsuback.MessageID = packet.MessageID
|
||||
|
||||
err1 := c.writeMessage(resp)
|
||||
if err1 != nil {
|
||||
log.Error("send ubsuback error, ", err1)
|
||||
err := c.WriterPacket(unsuback)
|
||||
if err != nil {
|
||||
log.Error("send unsuback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) processClientUnSubscribe(packet *packets.UnsubscribePacket) {
|
||||
if c.status == Disconnected {
|
||||
return
|
||||
}
|
||||
b := c.broker
|
||||
if b == nil {
|
||||
return
|
||||
}
|
||||
topics := packet.Topics
|
||||
|
||||
for _, topic := range topics {
|
||||
{
|
||||
//publish kafka
|
||||
|
||||
b.Publish(&bridge.Elements{
|
||||
ClientID: c.info.clientID,
|
||||
Username: c.info.username,
|
||||
Action: bridge.Unsubscribe,
|
||||
Timestamp: time.Now().Unix(),
|
||||
Topic: topic,
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
sub, exist := c.subMap[topic]
|
||||
if exist {
|
||||
c.topicsMgr.Unsubscribe([]byte(sub.topic), sub)
|
||||
c.session.RemoveTopic(topic)
|
||||
delete(c.subMap, topic)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
unsuback := packets.NewControlPacket(packets.Unsuback).(*packets.UnsubackPacket)
|
||||
unsuback.MessageID = packet.MessageID
|
||||
|
||||
err := c.WriterPacket(unsuback)
|
||||
if err != nil {
|
||||
log.Error("send unsuback error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
// //process ubsubscribe message
|
||||
if c.typ == CLIENT {
|
||||
b.BroadcastSubOrUnsubMessage(buf)
|
||||
}
|
||||
b.BroadcastSubOrUnsubMessage(packet)
|
||||
}
|
||||
|
||||
func (c *client) unsubscribe(sub *subscription) {
|
||||
|
||||
c.mu.Lock()
|
||||
delete(c.subs, string(sub.topic))
|
||||
c.mu.Unlock()
|
||||
|
||||
if c.broker != nil {
|
||||
c.broker.sl.Remove(sub)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) ProcessPing(buf []byte) {
|
||||
_, err := DecodePingreqMessage(buf)
|
||||
if err != nil {
|
||||
log.Error("Decode PingRequest Message error: ", err)
|
||||
c.Close()
|
||||
func (c *client) ProcessPing() {
|
||||
if c.status == Disconnected {
|
||||
return
|
||||
}
|
||||
|
||||
pingRspMsg := message.NewPingrespMessage()
|
||||
err = c.writeMessage(pingRspMsg)
|
||||
resp := packets.NewControlPacket(packets.Pingresp).(*packets.PingrespPacket)
|
||||
err := c.WriterPacket(resp)
|
||||
if err != nil {
|
||||
log.Error("send PingResponse error, ", err)
|
||||
log.Error("send PingResponse error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) Close() {
|
||||
b := c.broker
|
||||
subs := c.subs
|
||||
if b != nil {
|
||||
b.removeClient(c)
|
||||
for _, sub := range subs {
|
||||
err := b.sl.Remove(sub)
|
||||
if err != nil {
|
||||
log.Error("closed client but remove sublist error, ", err)
|
||||
}
|
||||
}
|
||||
if c.info.willMsg != nil {
|
||||
b.ProcessPublishMessage(c.info.willMsg)
|
||||
}
|
||||
if c.status == Disconnected {
|
||||
return
|
||||
}
|
||||
|
||||
c.cancelFunc()
|
||||
|
||||
c.status = Disconnected
|
||||
//wait for message complete
|
||||
// time.Sleep(1 * time.Second)
|
||||
// c.status = Disconnected
|
||||
|
||||
b := c.broker
|
||||
b.Publish(&bridge.Elements{
|
||||
ClientID: c.info.clientID,
|
||||
Username: c.info.username,
|
||||
Action: bridge.Disconnect,
|
||||
Timestamp: time.Now().Unix(),
|
||||
})
|
||||
|
||||
if c.conn != nil {
|
||||
c.conn.Close()
|
||||
c.conn = nil
|
||||
}
|
||||
|
||||
subs := c.subMap
|
||||
|
||||
if b != nil {
|
||||
b.removeClient(c)
|
||||
for _, sub := range subs {
|
||||
err := b.topicsMgr.Unsubscribe([]byte(sub.topic), sub)
|
||||
if err != nil {
|
||||
log.Error("unsubscribe error, ", zap.Error(err), zap.String("ClientID", c.info.clientID))
|
||||
}
|
||||
}
|
||||
|
||||
if c.typ == CLIENT {
|
||||
b.BroadcastUnSubscribe(subs)
|
||||
//offline notification
|
||||
b.OnlineOfflineNotification(c.info.clientID, false)
|
||||
}
|
||||
|
||||
if c.info.willMsg != nil {
|
||||
b.PublishMessage(c.info.willMsg)
|
||||
}
|
||||
|
||||
if c.typ == CLUSTER {
|
||||
b.ConnectToDiscovery()
|
||||
}
|
||||
|
||||
//do reconnect
|
||||
if c.typ == REMOTE {
|
||||
go b.connectRouter(c.route.remoteID, c.route.remoteUrl)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func WriteBuffer(conn net.Conn, buf []byte) error {
|
||||
if conn == nil {
|
||||
return errors.New("conn is nul")
|
||||
func (c *client) WriterPacket(packet packets.ControlPacket) error {
|
||||
if c.status == Disconnected {
|
||||
return nil
|
||||
}
|
||||
_, err := conn.Write(buf)
|
||||
return err
|
||||
}
|
||||
func (c *client) writeBuffer(buf []byte) error {
|
||||
|
||||
if packet == nil {
|
||||
return nil
|
||||
}
|
||||
if c.conn == nil {
|
||||
c.Close()
|
||||
return errors.New("connect lost ....")
|
||||
}
|
||||
|
||||
c.mu.Lock()
|
||||
err := WriteBuffer(c.conn, buf)
|
||||
err := packet.Write(c.conn)
|
||||
c.mu.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *client) writeMessage(msg message.Message) error {
|
||||
buf, err := EncodeMessage(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return c.writeBuffer(buf)
|
||||
}
|
||||
|
||||
@@ -1,73 +0,0 @@
|
||||
package broker
|
||||
|
||||
import "sync"
|
||||
|
||||
type cMap interface {
|
||||
Set(key string, val *client)
|
||||
Get(key string) (*client, bool)
|
||||
Items() map[string]*client
|
||||
Exist(key string) bool
|
||||
Update(key string, val *client) (*client, bool)
|
||||
Count() int
|
||||
Remove(key string)
|
||||
}
|
||||
|
||||
type clientMap struct {
|
||||
items map[string]*client
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewClientMap() cMap {
|
||||
smap := &clientMap{
|
||||
items: make(map[string]*client),
|
||||
}
|
||||
return smap
|
||||
}
|
||||
|
||||
func (s *clientMap) Set(key string, val *client) {
|
||||
s.mu.Lock()
|
||||
s.items[key] = val
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *clientMap) Get(key string) (*client, bool) {
|
||||
s.mu.RLock()
|
||||
val, ok := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
return val, ok
|
||||
}
|
||||
|
||||
func (s *clientMap) Exist(key string) bool {
|
||||
s.mu.RLock()
|
||||
_, ok := s.items[key]
|
||||
s.mu.RUnlock()
|
||||
return ok
|
||||
}
|
||||
|
||||
func (s *clientMap) Update(key string, val *client) (*client, bool) {
|
||||
s.mu.Lock()
|
||||
old, ok := s.items[key]
|
||||
s.items[key] = val
|
||||
s.mu.Unlock()
|
||||
return old, ok
|
||||
}
|
||||
|
||||
func (s *clientMap) Count() int {
|
||||
s.mu.RLock()
|
||||
len := len(s.items)
|
||||
s.mu.RUnlock()
|
||||
return len
|
||||
}
|
||||
|
||||
func (s *clientMap) Remove(key string) {
|
||||
s.mu.Lock()
|
||||
delete(s.items, key)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *clientMap) Items() map[string]*client {
|
||||
s.mu.RLock()
|
||||
items := s.items
|
||||
s.mu.RUnlock()
|
||||
return items
|
||||
}
|
||||
139
broker/comm.go
139
broker/comm.go
@@ -1,16 +1,17 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package broker
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/md5"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"io"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"go.uber.org/zap"
|
||||
|
||||
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||
uuid "github.com/satori/go.uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -40,54 +41,12 @@ const (
|
||||
PINGRESP
|
||||
DISCONNECT
|
||||
)
|
||||
|
||||
func SubscribeTopicCheckAndSpilt(subject []byte) ([]string, error) {
|
||||
|
||||
topic := string(subject)
|
||||
|
||||
if bytes.IndexByte(subject, '#') != -1 {
|
||||
if bytes.IndexByte(subject, '#') != len(subject)-1 {
|
||||
return nil, errors.New("Topic format error with index of #")
|
||||
}
|
||||
}
|
||||
|
||||
re := strings.Split(topic, "/")
|
||||
for i, v := range re {
|
||||
if i != 0 && i != (len(re)-1) {
|
||||
if v == "" {
|
||||
return nil, errors.New("Topic format error with index of //")
|
||||
}
|
||||
if strings.Contains(v, "+") && v != "+" {
|
||||
return nil, errors.New("Topic format error with index of +")
|
||||
}
|
||||
} else {
|
||||
if v == "" {
|
||||
re[i] = "/"
|
||||
}
|
||||
}
|
||||
}
|
||||
return re, nil
|
||||
|
||||
}
|
||||
|
||||
func PublishTopicCheckAndSpilt(subject []byte) ([]string, error) {
|
||||
if bytes.IndexByte(subject, '#') != -1 || bytes.IndexByte(subject, '+') != -1 {
|
||||
return nil, errors.New("Publish Topic format error with + and #")
|
||||
}
|
||||
topic := string(subject)
|
||||
re := strings.Split(topic, "/")
|
||||
for i, v := range re {
|
||||
if v == "" {
|
||||
if i != 0 && i != (len(re)-1) {
|
||||
return nil, errors.New("Topic format error with index of //")
|
||||
} else {
|
||||
re[i] = "/"
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
return re, nil
|
||||
}
|
||||
const (
|
||||
QosAtMostOnce byte = iota
|
||||
QosAtLeastOnce
|
||||
QosExactlyOnce
|
||||
QosFailure = 0x80
|
||||
)
|
||||
|
||||
func equal(k1, k2 interface{}) bool {
|
||||
if reflect.TypeOf(k1) != reflect.TypeOf(k2) {
|
||||
@@ -134,13 +93,65 @@ func equal(k1, k2 interface{}) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func GenUniqueId() string {
|
||||
b := make([]byte, 48)
|
||||
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||
return ""
|
||||
func addSubMap(m map[string]uint64, topic string) {
|
||||
subNum, exist := m[topic]
|
||||
if exist {
|
||||
m[topic] = subNum + 1
|
||||
} else {
|
||||
m[topic] = 1
|
||||
}
|
||||
}
|
||||
|
||||
func delSubMap(m map[string]uint64, topic string) uint64 {
|
||||
subNum, exist := m[topic]
|
||||
if exist {
|
||||
if subNum > 1 {
|
||||
m[topic] = subNum - 1
|
||||
return subNum - 1
|
||||
}
|
||||
} else {
|
||||
m[topic] = 0
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func GenUniqueId() string {
|
||||
return uuid.NewV4().String()
|
||||
}
|
||||
|
||||
func wrapPublishPacket(packet *packets.PublishPacket) *packets.PublishPacket {
|
||||
p := packet.Copy()
|
||||
wrapPayload := map[string]interface{}{
|
||||
"message_id": GenUniqueId(),
|
||||
"payload": string(p.Payload),
|
||||
}
|
||||
b, _ := json.Marshal(wrapPayload)
|
||||
p.Payload = b
|
||||
return p
|
||||
}
|
||||
|
||||
func unWrapPublishPacket(packet *packets.PublishPacket) *packets.PublishPacket {
|
||||
p := packet.Copy()
|
||||
if gjson.GetBytes(p.Payload, "payload").Exists() {
|
||||
p.Payload = []byte(gjson.GetBytes(p.Payload, "payload").String())
|
||||
}
|
||||
return p
|
||||
}
|
||||
|
||||
func publish(sub *subscription, packet *packets.PublishPacket) {
|
||||
// var p *packets.PublishPacket
|
||||
// if sub.client.info.username != "root" {
|
||||
// p = unWrapPublishPacket(packet)
|
||||
// } else {
|
||||
// p = wrapPublishPacket(packet)
|
||||
// }
|
||||
// err := sub.client.WriterPacket(p)
|
||||
// if err != nil {
|
||||
// log.Error("process message for psub error, ", zap.Error(err))
|
||||
// }
|
||||
|
||||
err := sub.client.WriterPacket(packet)
|
||||
if err != nil {
|
||||
log.Error("process message for psub error, ", zap.Error(err))
|
||||
}
|
||||
h := md5.New()
|
||||
h.Write([]byte(base64.URLEncoding.EncodeToString(b)))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
// return GetMd5String()
|
||||
}
|
||||
|
||||
141
broker/config.go
141
broker/config.go
@@ -1,23 +1,27 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package broker
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
|
||||
log "github.com/cihub/seelog"
|
||||
)
|
||||
|
||||
const (
|
||||
CONFIGFILE = "broker.config"
|
||||
"github.com/fhmq/hmq/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Worker int `json:"workerNum"`
|
||||
Host string `json:"host"`
|
||||
Port string `json:"port"`
|
||||
Cluster RouteInfo `json:"cluster"`
|
||||
Router string `json:"router"`
|
||||
TlsHost string `json:"tlsHost"`
|
||||
TlsPort string `json:"tlsPort"`
|
||||
WsPath string `json:"wsPath"`
|
||||
@@ -26,12 +30,18 @@ type Config struct {
|
||||
TlsInfo TLSInfo `json:"tlsInfo"`
|
||||
Acl bool `json:"acl"`
|
||||
AclConf string `json:"aclConf"`
|
||||
Debug bool `json:"debug"`
|
||||
Plugin Plugins `json:"plugins"`
|
||||
}
|
||||
|
||||
type Plugins struct {
|
||||
Auth string
|
||||
Bridge string
|
||||
}
|
||||
|
||||
type RouteInfo struct {
|
||||
Host string `json:"host"`
|
||||
Port string `json:"port"`
|
||||
Routes []string `json:"routes"`
|
||||
Host string `json:"host"`
|
||||
Port string `json:"port"`
|
||||
}
|
||||
|
||||
type TLSInfo struct {
|
||||
@@ -41,11 +51,95 @@ type TLSInfo struct {
|
||||
KeyFile string `json:"keyFile"`
|
||||
}
|
||||
|
||||
func LoadConfig() (*Config, error) {
|
||||
var DefaultConfig *Config = &Config{
|
||||
Worker: 4096,
|
||||
Host: "0.0.0.0",
|
||||
Port: "1883",
|
||||
Acl: false,
|
||||
}
|
||||
|
||||
content, err := ioutil.ReadFile(CONFIGFILE)
|
||||
var (
|
||||
log = logger.Prod().Named("broker")
|
||||
)
|
||||
|
||||
func showHelp() {
|
||||
fmt.Printf("%s\n", usageStr)
|
||||
os.Exit(0)
|
||||
}
|
||||
|
||||
func ConfigureConfig(args []string) (*Config, error) {
|
||||
config := &Config{}
|
||||
var (
|
||||
help bool
|
||||
configFile string
|
||||
)
|
||||
fs := flag.NewFlagSet("hmq-broker", flag.ExitOnError)
|
||||
fs.Usage = showHelp
|
||||
|
||||
fs.BoolVar(&help, "h", false, "Show this message.")
|
||||
fs.BoolVar(&help, "help", false, "Show this message.")
|
||||
fs.IntVar(&config.Worker, "w", 1024, "worker num to process message, perfer (client num)/10.")
|
||||
fs.IntVar(&config.Worker, "worker", 1024, "worker num to process message, perfer (client num)/10.")
|
||||
fs.StringVar(&config.Port, "port", "1883", "Port to listen on.")
|
||||
fs.StringVar(&config.Port, "p", "1883", "Port to listen on.")
|
||||
fs.StringVar(&config.Host, "host", "0.0.0.0", "Network host to listen on")
|
||||
fs.StringVar(&config.Cluster.Port, "cp", "", "Cluster port from which members can connect.")
|
||||
fs.StringVar(&config.Cluster.Port, "clusterport", "", "Cluster port from which members can connect.")
|
||||
fs.StringVar(&config.Router, "r", "", "Router who maintenance cluster info")
|
||||
fs.StringVar(&config.Router, "router", "", "Router who maintenance cluster info")
|
||||
fs.StringVar(&config.WsPort, "ws", "", "port for ws to listen on")
|
||||
fs.StringVar(&config.WsPort, "wsport", "", "port for ws to listen on")
|
||||
fs.StringVar(&config.WsPath, "wsp", "", "path for ws to listen on")
|
||||
fs.StringVar(&config.WsPath, "wspath", "", "path for ws to listen on")
|
||||
fs.StringVar(&configFile, "config", "", "config file for hmq")
|
||||
fs.StringVar(&configFile, "c", "", "config file for hmq")
|
||||
fs.BoolVar(&config.Debug, "debug", false, "enable Debug logging.")
|
||||
fs.BoolVar(&config.Debug, "d", false, "enable Debug logging.")
|
||||
|
||||
fs.Bool("D", true, "enable Debug logging.")
|
||||
|
||||
if err := fs.Parse(args); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if help {
|
||||
showHelp()
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
fs.Visit(func(f *flag.Flag) {
|
||||
switch f.Name {
|
||||
case "D":
|
||||
config.Debug = true
|
||||
}
|
||||
})
|
||||
|
||||
if configFile != "" {
|
||||
tmpConfig, e := LoadConfig(configFile)
|
||||
if e != nil {
|
||||
return nil, e
|
||||
} else {
|
||||
config = tmpConfig
|
||||
}
|
||||
}
|
||||
|
||||
if config.Debug {
|
||||
log = logger.Debug().Named("broker")
|
||||
}
|
||||
|
||||
if err := config.check(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
|
||||
}
|
||||
|
||||
func LoadConfig(filename string) (*Config, error) {
|
||||
|
||||
content, err := ioutil.ReadFile(filename)
|
||||
if err != nil {
|
||||
log.Error("Read config file error: ", err)
|
||||
// log.Error("Read config file error: ", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
// log.Info(string(content))
|
||||
@@ -53,10 +147,19 @@ func LoadConfig() (*Config, error) {
|
||||
var config Config
|
||||
err = json.Unmarshal(content, &config)
|
||||
if err != nil {
|
||||
log.Error("Unmarshal config file error: ", err)
|
||||
// log.Error("Unmarshal config file error: ", zap.Error(err))
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
}
|
||||
|
||||
func (config *Config) check() error {
|
||||
|
||||
if config.Worker == 0 {
|
||||
config.Worker = 1024
|
||||
}
|
||||
|
||||
if config.Port != "" {
|
||||
if config.Host == "" {
|
||||
config.Host = "0.0.0.0"
|
||||
@@ -68,29 +171,33 @@ func LoadConfig() (*Config, error) {
|
||||
config.Cluster.Host = "0.0.0.0"
|
||||
}
|
||||
}
|
||||
if config.Router != "" {
|
||||
if config.Cluster.Port == "" {
|
||||
return errors.New("cluster port is null")
|
||||
}
|
||||
}
|
||||
|
||||
if config.TlsPort != "" {
|
||||
if config.TlsInfo.CertFile == "" || config.TlsInfo.KeyFile == "" {
|
||||
log.Error("tls config error, no cert or key file.")
|
||||
return nil, err
|
||||
return errors.New("tls config error, no cert or key file.")
|
||||
}
|
||||
if config.TlsHost == "" {
|
||||
config.TlsHost = "0.0.0.0"
|
||||
}
|
||||
}
|
||||
|
||||
return &config, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewTLSConfig(tlsInfo TLSInfo) (*tls.Config, error) {
|
||||
|
||||
cert, err := tls.LoadX509KeyPair(tlsInfo.CertFile, tlsInfo.KeyFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", err)
|
||||
return nil, fmt.Errorf("error parsing X509 certificate/key pair: %v", zap.Error(err))
|
||||
}
|
||||
cert.Leaf, err = x509.ParseCertificate(cert.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", err)
|
||||
return nil, fmt.Errorf("error parsing certificate: %v", zap.Error(err))
|
||||
}
|
||||
|
||||
// Create TLSConfig
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
package broker
|
||||
|
||||
const (
|
||||
WorkNum = 2048
|
||||
)
|
||||
|
||||
type Dispatcher struct {
|
||||
WorkerPool chan chan *Message
|
||||
}
|
||||
|
||||
func init() {
|
||||
InitMessagePool()
|
||||
dispatcher := NewDispatcher()
|
||||
dispatcher.Run()
|
||||
}
|
||||
|
||||
func (d *Dispatcher) Run() {
|
||||
// starting n number of workers
|
||||
for i := 0; i < WorkNum; i++ {
|
||||
worker := NewWorker(d.WorkerPool)
|
||||
worker.Start()
|
||||
}
|
||||
go d.dispatch()
|
||||
}
|
||||
|
||||
func NewDispatcher() *Dispatcher {
|
||||
pool := make(chan chan *Message, WorkNum)
|
||||
return &Dispatcher{WorkerPool: pool}
|
||||
}
|
||||
|
||||
func (d *Dispatcher) dispatch() {
|
||||
for i := 0; i < MessagePoolNum; i++ {
|
||||
go func(idx int) {
|
||||
for {
|
||||
select {
|
||||
case msg := <-MSGPool[idx].queue:
|
||||
go func(msg *Message) {
|
||||
msgChannel := <-d.WorkerPool
|
||||
msgChannel <- msg
|
||||
}(msg)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
}
|
||||
26
broker/http.go
Normal file
26
broker/http.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package broker
|
||||
|
||||
import (
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func InitHTTPMoniter(b *Broker) {
|
||||
gin.SetMode(gin.ReleaseMode)
|
||||
router := gin.Default()
|
||||
router.DELETE("api/v1/connections/:clientid", func(c *gin.Context) {
|
||||
clientid := c.Param("clientid")
|
||||
cli, ok := b.clients.Load(clientid)
|
||||
if ok {
|
||||
conn, succss := cli.(*client)
|
||||
if succss {
|
||||
conn.Close()
|
||||
}
|
||||
}
|
||||
resp := map[string]int{
|
||||
"code": 0,
|
||||
}
|
||||
c.JSON(200, &resp)
|
||||
})
|
||||
|
||||
router.Run(":8080")
|
||||
}
|
||||
114
broker/info.go
114
broker/info.go
@@ -1,113 +1,113 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package broker
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hmq/lib/message"
|
||||
"time"
|
||||
|
||||
simplejson "github.com/bitly/go-simplejson"
|
||||
log "github.com/cihub/seelog"
|
||||
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func (c *client) SendInfo() {
|
||||
if c.status == Disconnected {
|
||||
return
|
||||
}
|
||||
url := c.info.localIP + ":" + c.broker.config.Cluster.Port
|
||||
|
||||
infoMsg := NewInfo(c.broker.id, url, false)
|
||||
err := c.writeMessage(infoMsg)
|
||||
err := c.WriterPacket(infoMsg)
|
||||
if err != nil {
|
||||
log.Error("send info message error, ", err)
|
||||
log.Error("send info message error, ", zap.Error(err))
|
||||
return
|
||||
}
|
||||
// log.Info("send info success")
|
||||
}
|
||||
|
||||
func (c *client) StartPing() {
|
||||
timeTicker := time.NewTicker(time.Second * 30)
|
||||
ping := message.NewPingreqMessage()
|
||||
timeTicker := time.NewTicker(time.Second * 50)
|
||||
ping := packets.NewControlPacket(packets.Pingreq).(*packets.PingreqPacket)
|
||||
for {
|
||||
select {
|
||||
case <-timeTicker.C:
|
||||
err := c.writeMessage(ping)
|
||||
err := c.WriterPacket(ping)
|
||||
if err != nil {
|
||||
log.Error("ping error: ", err)
|
||||
log.Error("ping error: ", zap.Error(err))
|
||||
c.Close()
|
||||
}
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *client) SendConnect() {
|
||||
|
||||
clientID := c.info.clientID
|
||||
connMsg := message.NewConnectMessage()
|
||||
connMsg.SetClientId(clientID)
|
||||
connMsg.SetVersion(0x04)
|
||||
err := c.writeMessage(connMsg)
|
||||
if err != nil {
|
||||
log.Error("send connect message error, ", err)
|
||||
if c.status != Connected {
|
||||
return
|
||||
}
|
||||
// log.Info("send connet success")
|
||||
m := packets.NewControlPacket(packets.Connect).(*packets.ConnectPacket)
|
||||
|
||||
m.CleanSession = true
|
||||
m.ClientIdentifier = c.info.clientID
|
||||
m.Keepalive = uint16(60)
|
||||
err := c.WriterPacket(m)
|
||||
if err != nil {
|
||||
log.Error("send connect message error, ", zap.Error(err))
|
||||
return
|
||||
}
|
||||
log.Info("send connect success")
|
||||
}
|
||||
|
||||
func NewInfo(sid, url string, isforword bool) *message.PublishMessage {
|
||||
infoMsg := message.NewPublishMessage()
|
||||
infoMsg.SetTopic([]byte(BrokerInfoTopic))
|
||||
info := fmt.Sprintf(`{"remoteID":"%s","url":"%s","isForward":%t}`, sid, url, isforword)
|
||||
func NewInfo(sid, url string, isforword bool) *packets.PublishPacket {
|
||||
pub := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
|
||||
pub.Qos = 0
|
||||
pub.TopicName = BrokerInfoTopic
|
||||
pub.Retain = false
|
||||
info := fmt.Sprintf(`{"brokerID":"%s","brokerUrl":"%s"}`, sid, url)
|
||||
// log.Info("new info", string(info))
|
||||
infoMsg.SetPayload([]byte(info))
|
||||
infoMsg.SetQoS(0)
|
||||
infoMsg.SetRetain(false)
|
||||
return infoMsg
|
||||
pub.Payload = []byte(info)
|
||||
return pub
|
||||
}
|
||||
|
||||
func (c *client) ProcessInfo(msg *message.PublishMessage) {
|
||||
func (c *client) ProcessInfo(packet *packets.PublishPacket) {
|
||||
nc := c.conn
|
||||
b := c.broker
|
||||
if nc == nil {
|
||||
return
|
||||
}
|
||||
|
||||
log.Info("recv remoteInfo: ", string(msg.Payload()))
|
||||
log.Info("recv remoteInfo: ", zap.String("payload", string(packet.Payload)))
|
||||
|
||||
js, e := simplejson.NewJson(msg.Payload())
|
||||
if e != nil {
|
||||
log.Warn("parse info message err", e)
|
||||
js, err := simplejson.NewJson(packet.Payload)
|
||||
if err != nil {
|
||||
log.Warn("parse info message err", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
rid := js.Get("remoteID").MustString()
|
||||
rurl := js.Get("url").MustString()
|
||||
isForward := js.Get("isForward").MustBool()
|
||||
|
||||
if rid == "" {
|
||||
log.Error("receive info message error with remoteID is null")
|
||||
routes, err := js.Get("data").Map()
|
||||
if routes == nil {
|
||||
log.Error("receive info message error, ", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
if rid == b.id {
|
||||
if !isForward {
|
||||
c.Close() //close connet self
|
||||
b.nodes = routes
|
||||
|
||||
b.mu.Lock()
|
||||
for rid, rurl := range routes {
|
||||
if rid == b.id {
|
||||
continue
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
exist := b.CheckRemoteExist(rid, rurl)
|
||||
if !exist {
|
||||
go b.connectRouter(rurl, rid)
|
||||
}
|
||||
// log.Info("isforword: ", isForward)
|
||||
if !isForward {
|
||||
route := &route{
|
||||
remoteUrl: rurl,
|
||||
remoteID: rid,
|
||||
url, ok := rurl.(string)
|
||||
if ok {
|
||||
exist := b.CheckRemoteExist(rid, url)
|
||||
if !exist {
|
||||
b.connectRouter(rid, url)
|
||||
}
|
||||
}
|
||||
c.route = route
|
||||
|
||||
go b.SendLocalSubsToRouter(c)
|
||||
// log.Info("BroadcastInfoMessage starting... ")
|
||||
infoMsg := NewInfo(rid, rurl, true)
|
||||
b.BroadcastInfoMessage(rid, infoMsg)
|
||||
}
|
||||
|
||||
return
|
||||
b.mu.Unlock()
|
||||
}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package acl
|
||||
|
||||
import (
|
||||
@@ -1,3 +1,4 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>*/
|
||||
package acl
|
||||
|
||||
import "strings"
|
||||
@@ -1,3 +1,5 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package acl
|
||||
|
||||
import (
|
||||
62
broker/lib/sessions/memprovider.go
Normal file
62
broker/lib/sessions/memprovider.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
)
|
||||
|
||||
var _ SessionsProvider = (*memProvider)(nil)
|
||||
|
||||
func init() {
|
||||
Register("mem", NewMemProvider())
|
||||
}
|
||||
|
||||
type memProvider struct {
|
||||
st map[string]*Session
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMemProvider() *memProvider {
|
||||
return &memProvider{
|
||||
st: make(map[string]*Session),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *memProvider) New(id string) (*Session, error) {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
|
||||
this.st[id] = &Session{id: id}
|
||||
return this.st[id], nil
|
||||
}
|
||||
|
||||
func (this *memProvider) Get(id string) (*Session, error) {
|
||||
this.mu.RLock()
|
||||
defer this.mu.RUnlock()
|
||||
|
||||
sess, ok := this.st[id]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("store/Get: No session found for key %s", id)
|
||||
}
|
||||
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (this *memProvider) Del(id string) {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
delete(this.st, id)
|
||||
}
|
||||
|
||||
func (this *memProvider) Save(id string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *memProvider) Count() int {
|
||||
return len(this.st)
|
||||
}
|
||||
|
||||
func (this *memProvider) Close() error {
|
||||
this.st = make(map[string]*Session)
|
||||
return nil
|
||||
}
|
||||
149
broker/lib/sessions/session.go
Normal file
149
broker/lib/sessions/session.go
Normal file
@@ -0,0 +1,149 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||
)
|
||||
|
||||
const (
|
||||
// Queue size for the ack queue
|
||||
defaultQueueSize = 16
|
||||
)
|
||||
|
||||
type Session struct {
|
||||
|
||||
// cmsg is the CONNECT message
|
||||
cmsg *packets.ConnectPacket
|
||||
|
||||
// Will message to publish if connect is closed unexpectedly
|
||||
Will *packets.PublishPacket
|
||||
|
||||
// Retained publish message
|
||||
Retained *packets.PublishPacket
|
||||
|
||||
// topics stores all the topis for this session/client
|
||||
topics map[string]byte
|
||||
|
||||
// Initialized?
|
||||
initted bool
|
||||
|
||||
// Serialize access to this session
|
||||
mu sync.Mutex
|
||||
|
||||
id string
|
||||
}
|
||||
|
||||
func (this *Session) Init(msg *packets.ConnectPacket) error {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
|
||||
if this.initted {
|
||||
return fmt.Errorf("Session already initialized")
|
||||
}
|
||||
|
||||
this.cmsg = msg
|
||||
|
||||
if this.cmsg.WillFlag {
|
||||
this.Will = packets.NewControlPacket(packets.Publish).(*packets.PublishPacket)
|
||||
this.Will.Qos = this.cmsg.Qos
|
||||
this.Will.TopicName = this.cmsg.WillTopic
|
||||
this.Will.Payload = this.cmsg.WillMessage
|
||||
this.Will.Retain = this.cmsg.WillRetain
|
||||
}
|
||||
|
||||
this.topics = make(map[string]byte, 1)
|
||||
|
||||
this.id = string(msg.ClientIdentifier)
|
||||
|
||||
this.initted = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Session) Update(msg *packets.ConnectPacket) error {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
|
||||
this.cmsg = msg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Session) RetainMessage(msg *packets.PublishPacket) error {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
|
||||
this.Retained = msg
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Session) AddTopic(topic string, qos byte) error {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
|
||||
if !this.initted {
|
||||
return fmt.Errorf("Session not yet initialized")
|
||||
}
|
||||
|
||||
this.topics[topic] = qos
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Session) RemoveTopic(topic string) error {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
|
||||
if !this.initted {
|
||||
return fmt.Errorf("Session not yet initialized")
|
||||
}
|
||||
|
||||
delete(this.topics, topic)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *Session) Topics() ([]string, []byte, error) {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
|
||||
if !this.initted {
|
||||
return nil, nil, fmt.Errorf("Session not yet initialized")
|
||||
}
|
||||
|
||||
var (
|
||||
topics []string
|
||||
qoss []byte
|
||||
)
|
||||
|
||||
for k, v := range this.topics {
|
||||
topics = append(topics, k)
|
||||
qoss = append(qoss, v)
|
||||
}
|
||||
|
||||
return topics, qoss, nil
|
||||
}
|
||||
|
||||
func (this *Session) ID() string {
|
||||
return this.cmsg.ClientIdentifier
|
||||
}
|
||||
|
||||
func (this *Session) WillFlag() bool {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
return this.cmsg.WillFlag
|
||||
}
|
||||
|
||||
func (this *Session) SetWillFlag(v bool) {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
this.cmsg.WillFlag = v
|
||||
}
|
||||
|
||||
func (this *Session) CleanSession() bool {
|
||||
this.mu.Lock()
|
||||
defer this.mu.Unlock()
|
||||
return this.cmsg.CleanSession
|
||||
}
|
||||
92
broker/lib/sessions/sessions.go
Normal file
92
broker/lib/sessions/sessions.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package sessions
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrSessionsProviderNotFound = errors.New("Session: Session provider not found")
|
||||
ErrKeyNotAvailable = errors.New("Session: not item found for key.")
|
||||
|
||||
providers = make(map[string]SessionsProvider)
|
||||
)
|
||||
|
||||
type SessionsProvider interface {
|
||||
New(id string) (*Session, error)
|
||||
Get(id string) (*Session, error)
|
||||
Del(id string)
|
||||
Save(id string) error
|
||||
Count() int
|
||||
Close() error
|
||||
}
|
||||
|
||||
// Register makes a session provider available by the provided name.
|
||||
// If a Register is called twice with the same name or if the driver is nil,
|
||||
// it panics.
|
||||
func Register(name string, provider SessionsProvider) {
|
||||
if provider == nil {
|
||||
panic("session: Register provide is nil")
|
||||
}
|
||||
|
||||
if _, dup := providers[name]; dup {
|
||||
panic("session: Register called twice for provider " + name)
|
||||
}
|
||||
|
||||
providers[name] = provider
|
||||
}
|
||||
|
||||
func Unregister(name string) {
|
||||
delete(providers, name)
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
p SessionsProvider
|
||||
}
|
||||
|
||||
func NewManager(providerName string) (*Manager, error) {
|
||||
p, ok := providers[providerName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session: unknown provider %q", providerName)
|
||||
}
|
||||
|
||||
return &Manager{p: p}, nil
|
||||
}
|
||||
|
||||
func (this *Manager) New(id string) (*Session, error) {
|
||||
if id == "" {
|
||||
id = this.sessionId()
|
||||
}
|
||||
return this.p.New(id)
|
||||
}
|
||||
|
||||
func (this *Manager) Get(id string) (*Session, error) {
|
||||
return this.p.Get(id)
|
||||
}
|
||||
|
||||
func (this *Manager) Del(id string) {
|
||||
this.p.Del(id)
|
||||
}
|
||||
|
||||
func (this *Manager) Save(id string) error {
|
||||
return this.p.Save(id)
|
||||
}
|
||||
|
||||
func (this *Manager) Count() int {
|
||||
return this.p.Count()
|
||||
}
|
||||
|
||||
func (this *Manager) Close() error {
|
||||
return this.p.Close()
|
||||
}
|
||||
|
||||
func (manager *Manager) sessionId() string {
|
||||
b := make([]byte, 15)
|
||||
if _, err := io.ReadFull(rand.Reader, b); err != nil {
|
||||
return ""
|
||||
}
|
||||
return base64.URLEncoding.EncodeToString(b)
|
||||
}
|
||||
549
broker/lib/topics/memtopics.go
Normal file
549
broker/lib/topics/memtopics.go
Normal file
@@ -0,0 +1,549 @@
|
||||
package topics
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||
)
|
||||
|
||||
const (
|
||||
QosAtMostOnce byte = iota
|
||||
QosAtLeastOnce
|
||||
QosExactlyOnce
|
||||
QosFailure = 0x80
|
||||
)
|
||||
|
||||
var _ TopicsProvider = (*memTopics)(nil)
|
||||
|
||||
type memTopics struct {
|
||||
// Sub/unsub mutex
|
||||
smu sync.RWMutex
|
||||
// Subscription tree
|
||||
sroot *snode
|
||||
|
||||
// Retained message mutex
|
||||
rmu sync.RWMutex
|
||||
// Retained messages topic tree
|
||||
rroot *rnode
|
||||
}
|
||||
|
||||
func init() {
|
||||
Register("mem", NewMemProvider())
|
||||
}
|
||||
|
||||
// NewMemProvider returns an new instance of the memTopics, which is implements the
|
||||
// TopicsProvider interface. memProvider is a hidden struct that stores the topic
|
||||
// subscriptions and retained messages in memory. The content is not persistend so
|
||||
// when the server goes, everything will be gone. Use with care.
|
||||
func NewMemProvider() *memTopics {
|
||||
return &memTopics{
|
||||
sroot: newSNode(),
|
||||
rroot: newRNode(),
|
||||
}
|
||||
}
|
||||
|
||||
func ValidQos(qos byte) bool {
|
||||
return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce
|
||||
}
|
||||
|
||||
func (this *memTopics) Subscribe(topic []byte, qos byte, sub interface{}) (byte, error) {
|
||||
if !ValidQos(qos) {
|
||||
return QosFailure, fmt.Errorf("Invalid QoS %d", qos)
|
||||
}
|
||||
|
||||
if sub == nil {
|
||||
return QosFailure, fmt.Errorf("Subscriber cannot be nil")
|
||||
}
|
||||
|
||||
this.smu.Lock()
|
||||
defer this.smu.Unlock()
|
||||
|
||||
if qos > QosExactlyOnce {
|
||||
qos = QosExactlyOnce
|
||||
}
|
||||
|
||||
if err := this.sroot.sinsert(topic, qos, sub); err != nil {
|
||||
return QosFailure, err
|
||||
}
|
||||
|
||||
return qos, nil
|
||||
}
|
||||
|
||||
func (this *memTopics) Unsubscribe(topic []byte, sub interface{}) error {
|
||||
this.smu.Lock()
|
||||
defer this.smu.Unlock()
|
||||
|
||||
return this.sroot.sremove(topic, sub)
|
||||
}
|
||||
|
||||
// Returned values will be invalidated by the next Subscribers call
|
||||
func (this *memTopics) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
|
||||
if !ValidQos(qos) {
|
||||
return fmt.Errorf("Invalid QoS %d", qos)
|
||||
}
|
||||
|
||||
this.smu.RLock()
|
||||
defer this.smu.RUnlock()
|
||||
|
||||
*subs = (*subs)[0:0]
|
||||
*qoss = (*qoss)[0:0]
|
||||
|
||||
return this.sroot.smatch(topic, qos, subs, qoss)
|
||||
}
|
||||
|
||||
func (this *memTopics) Retain(msg *packets.PublishPacket) error {
|
||||
this.rmu.Lock()
|
||||
defer this.rmu.Unlock()
|
||||
|
||||
// So apparently, at least according to the MQTT Conformance/Interoperability
|
||||
// Testing, that a payload of 0 means delete the retain message.
|
||||
// https://eclipse.org/paho/clients/testing/
|
||||
if len(msg.Payload) == 0 {
|
||||
return this.rroot.rremove([]byte(msg.TopicName))
|
||||
}
|
||||
|
||||
return this.rroot.rinsert([]byte(msg.TopicName), msg)
|
||||
}
|
||||
|
||||
func (this *memTopics) Retained(topic []byte, msgs *[]*packets.PublishPacket) error {
|
||||
this.rmu.RLock()
|
||||
defer this.rmu.RUnlock()
|
||||
|
||||
return this.rroot.rmatch(topic, msgs)
|
||||
}
|
||||
|
||||
func (this *memTopics) Close() error {
|
||||
this.sroot = nil
|
||||
this.rroot = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// subscrition nodes
|
||||
type snode struct {
|
||||
// If this is the end of the topic string, then add subscribers here
|
||||
subs []interface{}
|
||||
qos []byte
|
||||
|
||||
// Otherwise add the next topic level here
|
||||
snodes map[string]*snode
|
||||
}
|
||||
|
||||
func newSNode() *snode {
|
||||
return &snode{
|
||||
snodes: make(map[string]*snode),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *snode) sinsert(topic []byte, qos byte, sub interface{}) error {
|
||||
// If there's no more topic levels, that means we are at the matching snode
|
||||
// to insert the subscriber. So let's see if there's such subscriber,
|
||||
// if so, update it. Otherwise insert it.
|
||||
if len(topic) == 0 {
|
||||
// Let's see if the subscriber is already on the list. If yes, update
|
||||
// QoS and then return.
|
||||
for i := range this.subs {
|
||||
if equal(this.subs[i], sub) {
|
||||
this.qos[i] = qos
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Otherwise add.
|
||||
this.subs = append(this.subs, sub)
|
||||
this.qos = append(this.qos, qos)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Not the last level, so let's find or create the next level snode, and
|
||||
// recursively call it's insert().
|
||||
|
||||
// ntl = next topic level
|
||||
ntl, rem, err := nextTopicLevel(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
level := string(ntl)
|
||||
|
||||
// Add snode if it doesn't already exist
|
||||
n, ok := this.snodes[level]
|
||||
if !ok {
|
||||
n = newSNode()
|
||||
this.snodes[level] = n
|
||||
}
|
||||
|
||||
return n.sinsert(rem, qos, sub)
|
||||
}
|
||||
|
||||
// This remove implementation ignores the QoS, as long as the subscriber
|
||||
// matches then it's removed
|
||||
func (this *snode) sremove(topic []byte, sub interface{}) error {
|
||||
// If the topic is empty, it means we are at the final matching snode. If so,
|
||||
// let's find the matching subscribers and remove them.
|
||||
if len(topic) == 0 {
|
||||
// If subscriber == nil, then it's signal to remove ALL subscribers
|
||||
if sub == nil {
|
||||
this.subs = this.subs[0:0]
|
||||
this.qos = this.qos[0:0]
|
||||
return nil
|
||||
}
|
||||
|
||||
// If we find the subscriber then remove it from the list. Technically
|
||||
// we just overwrite the slot by shifting all other items up by one.
|
||||
for i := range this.subs {
|
||||
if equal(this.subs[i], sub) {
|
||||
this.subs = append(this.subs[:i], this.subs[i+1:]...)
|
||||
this.qos = append(this.qos[:i], this.qos[i+1:]...)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("No topic found for subscriber")
|
||||
}
|
||||
|
||||
// Not the last level, so let's find the next level snode, and recursively
|
||||
// call it's remove().
|
||||
|
||||
// ntl = next topic level
|
||||
ntl, rem, err := nextTopicLevel(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
level := string(ntl)
|
||||
|
||||
// Find the snode that matches the topic level
|
||||
n, ok := this.snodes[level]
|
||||
if !ok {
|
||||
return fmt.Errorf("No topic found")
|
||||
}
|
||||
|
||||
// Remove the subscriber from the next level snode
|
||||
if err := n.sremove(rem, sub); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there are no more subscribers and snodes to the next level we just visited
|
||||
// let's remove it
|
||||
if len(n.subs) == 0 && len(n.snodes) == 0 {
|
||||
delete(this.snodes, level)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// smatch() returns all the subscribers that are subscribed to the topic. Given a topic
|
||||
// with no wildcards (publish topic), it returns a list of subscribers that subscribes
|
||||
// to the topic. For each of the level names, it's a match
|
||||
// - if there are subscribers to '#', then all the subscribers are added to result set
|
||||
func (this *snode) smatch(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
|
||||
// If the topic is empty, it means we are at the final matching snode. If so,
|
||||
// let's find the subscribers that match the qos and append them to the list.
|
||||
if len(topic) == 0 {
|
||||
this.matchQos(qos, subs, qoss)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ntl = next topic level
|
||||
ntl, rem, err := nextTopicLevel(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
level := string(ntl)
|
||||
|
||||
for k, n := range this.snodes {
|
||||
// If the key is "#", then these subscribers are added to the result set
|
||||
if k == MWC {
|
||||
n.matchQos(qos, subs, qoss)
|
||||
} else if k == SWC || k == level {
|
||||
if err := n.smatch(rem, qos, subs, qoss); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// retained message nodes
|
||||
type rnode struct {
|
||||
// If this is the end of the topic string, then add retained messages here
|
||||
msg *packets.PublishPacket
|
||||
// Otherwise add the next topic level here
|
||||
rnodes map[string]*rnode
|
||||
}
|
||||
|
||||
func newRNode() *rnode {
|
||||
return &rnode{
|
||||
rnodes: make(map[string]*rnode),
|
||||
}
|
||||
}
|
||||
|
||||
func (this *rnode) rinsert(topic []byte, msg *packets.PublishPacket) error {
|
||||
// If there's no more topic levels, that means we are at the matching rnode.
|
||||
if len(topic) == 0 {
|
||||
// Reuse the message if possible
|
||||
if this.msg == nil {
|
||||
this.msg = msg
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Not the last level, so let's find or create the next level snode, and
|
||||
// recursively call it's insert().
|
||||
|
||||
// ntl = next topic level
|
||||
ntl, rem, err := nextTopicLevel(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
level := string(ntl)
|
||||
|
||||
// Add snode if it doesn't already exist
|
||||
n, ok := this.rnodes[level]
|
||||
if !ok {
|
||||
n = newRNode()
|
||||
this.rnodes[level] = n
|
||||
}
|
||||
|
||||
return n.rinsert(rem, msg)
|
||||
}
|
||||
|
||||
// Remove the retained message for the supplied topic
|
||||
func (this *rnode) rremove(topic []byte) error {
|
||||
// If the topic is empty, it means we are at the final matching rnode. If so,
|
||||
// let's remove the buffer and message.
|
||||
if len(topic) == 0 {
|
||||
this.msg = nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// Not the last level, so let's find the next level rnode, and recursively
|
||||
// call it's remove().
|
||||
|
||||
// ntl = next topic level
|
||||
ntl, rem, err := nextTopicLevel(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
level := string(ntl)
|
||||
|
||||
// Find the rnode that matches the topic level
|
||||
n, ok := this.rnodes[level]
|
||||
if !ok {
|
||||
return fmt.Errorf("No topic found")
|
||||
}
|
||||
|
||||
// Remove the subscriber from the next level rnode
|
||||
if err := n.rremove(rem); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If there are no more rnodes to the next level we just visited let's remove it
|
||||
if len(n.rnodes) == 0 {
|
||||
delete(this.rnodes, level)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// rmatch() finds the retained messages for the topic and qos provided. It's somewhat
|
||||
// of a reverse match compare to match() since the supplied topic can contain
|
||||
// wildcards, whereas the retained message topic is a full (no wildcard) topic.
|
||||
func (this *rnode) rmatch(topic []byte, msgs *[]*packets.PublishPacket) error {
|
||||
// If the topic is empty, it means we are at the final matching rnode. If so,
|
||||
// add the retained msg to the list.
|
||||
if len(topic) == 0 {
|
||||
if this.msg != nil {
|
||||
*msgs = append(*msgs, this.msg)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ntl = next topic level
|
||||
ntl, rem, err := nextTopicLevel(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
level := string(ntl)
|
||||
|
||||
if level == MWC {
|
||||
// If '#', add all retained messages starting this node
|
||||
this.allRetained(msgs)
|
||||
} else if level == SWC {
|
||||
// If '+', check all nodes at this level. Next levels must be matched.
|
||||
for _, n := range this.rnodes {
|
||||
if err := n.rmatch(rem, msgs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Otherwise, find the matching node, go to the next level
|
||||
if n, ok := this.rnodes[level]; ok {
|
||||
if err := n.rmatch(rem, msgs); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *rnode) allRetained(msgs *[]*packets.PublishPacket) {
|
||||
if this.msg != nil {
|
||||
*msgs = append(*msgs, this.msg)
|
||||
}
|
||||
|
||||
for _, n := range this.rnodes {
|
||||
n.allRetained(msgs)
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
stateCHR byte = iota // Regular character
|
||||
stateMWC // Multi-level wildcard
|
||||
stateSWC // Single-level wildcard
|
||||
stateSEP // Topic level separator
|
||||
stateSYS // System level topic ($)
|
||||
)
|
||||
|
||||
// Returns topic level, remaining topic levels and any errors
|
||||
func nextTopicLevel(topic []byte) ([]byte, []byte, error) {
|
||||
s := stateCHR
|
||||
|
||||
for i, c := range topic {
|
||||
switch c {
|
||||
case '/':
|
||||
if s == stateMWC {
|
||||
return nil, nil, fmt.Errorf("Multi-level wildcard found in topic and it's not at the last level")
|
||||
}
|
||||
|
||||
if i == 0 {
|
||||
return []byte(SWC), topic[i+1:], nil
|
||||
}
|
||||
|
||||
return topic[:i], topic[i+1:], nil
|
||||
|
||||
case '#':
|
||||
if i != 0 {
|
||||
return nil, nil, fmt.Errorf("Wildcard character '#' must occupy entire topic level")
|
||||
}
|
||||
|
||||
s = stateMWC
|
||||
|
||||
case '+':
|
||||
if i != 0 {
|
||||
return nil, nil, fmt.Errorf("Wildcard character '+' must occupy entire topic level")
|
||||
}
|
||||
|
||||
s = stateSWC
|
||||
|
||||
// case '$':
|
||||
// if i == 0 {
|
||||
// return nil, nil, fmt.Errorf("Cannot publish to $ topics")
|
||||
// }
|
||||
|
||||
// s = stateSYS
|
||||
|
||||
default:
|
||||
if s == stateMWC || s == stateSWC {
|
||||
return nil, nil, fmt.Errorf("Wildcard characters '#' and '+' must occupy entire topic level")
|
||||
}
|
||||
|
||||
s = stateCHR
|
||||
}
|
||||
}
|
||||
|
||||
// If we got here that means we didn't hit the separator along the way, so the
|
||||
// topic is either empty, or does not contain a separator. Either way, we return
|
||||
// the full topic
|
||||
return topic, nil, nil
|
||||
}
|
||||
|
||||
// The QoS of the payload messages sent in response to a subscription must be the
|
||||
// minimum of the QoS of the originally published message (in this case, it's the
|
||||
// qos parameter) and the maximum QoS granted by the server (in this case, it's
|
||||
// the QoS in the topic tree).
|
||||
//
|
||||
// It's also possible that even if the topic matches, the subscriber is not included
|
||||
// due to the QoS granted is lower than the published message QoS. For example,
|
||||
// if the client is granted only QoS 0, and the publish message is QoS 1, then this
|
||||
// client is not to be send the published message.
|
||||
func (this *snode) matchQos(qos byte, subs *[]interface{}, qoss *[]byte) {
|
||||
for _, sub := range this.subs {
|
||||
// If the published QoS is higher than the subscriber QoS, then we skip the
|
||||
// subscriber. Otherwise, add to the list.
|
||||
// if qos >= this.qos[i] {
|
||||
*subs = append(*subs, sub)
|
||||
*qoss = append(*qoss, qos)
|
||||
// }
|
||||
}
|
||||
}
|
||||
|
||||
func equal(k1, k2 interface{}) bool {
|
||||
if reflect.TypeOf(k1) != reflect.TypeOf(k2) {
|
||||
return false
|
||||
}
|
||||
|
||||
if reflect.ValueOf(k1).Kind() == reflect.Func {
|
||||
return &k1 == &k2
|
||||
}
|
||||
|
||||
if k1 == k2 {
|
||||
return true
|
||||
}
|
||||
|
||||
switch k1 := k1.(type) {
|
||||
case string:
|
||||
return k1 == k2.(string)
|
||||
|
||||
case int64:
|
||||
return k1 == k2.(int64)
|
||||
|
||||
case int32:
|
||||
return k1 == k2.(int32)
|
||||
|
||||
case int16:
|
||||
return k1 == k2.(int16)
|
||||
|
||||
case int8:
|
||||
return k1 == k2.(int8)
|
||||
|
||||
case int:
|
||||
return k1 == k2.(int)
|
||||
|
||||
case float32:
|
||||
return k1 == k2.(float32)
|
||||
|
||||
case float64:
|
||||
return k1 == k2.(float64)
|
||||
|
||||
case uint:
|
||||
return k1 == k2.(uint)
|
||||
|
||||
case uint8:
|
||||
return k1 == k2.(uint8)
|
||||
|
||||
case uint16:
|
||||
return k1 == k2.(uint16)
|
||||
|
||||
case uint32:
|
||||
return k1 == k2.(uint32)
|
||||
|
||||
case uint64:
|
||||
return k1 == k2.(uint64)
|
||||
|
||||
case uintptr:
|
||||
return k1 == k2.(uintptr)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
91
broker/lib/topics/topics.go
Normal file
91
broker/lib/topics/topics.go
Normal file
@@ -0,0 +1,91 @@
|
||||
package topics
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/eclipse/paho.mqtt.golang/packets"
|
||||
)
|
||||
|
||||
const (
|
||||
// MWC is the multi-level wildcard
|
||||
MWC = "#"
|
||||
|
||||
// SWC is the single level wildcard
|
||||
SWC = "+"
|
||||
|
||||
// SEP is the topic level separator
|
||||
SEP = "/"
|
||||
|
||||
// SYS is the starting character of the system level topics
|
||||
SYS = "$"
|
||||
|
||||
// Both wildcards
|
||||
_WC = "#+"
|
||||
)
|
||||
|
||||
var (
|
||||
providers = make(map[string]TopicsProvider)
|
||||
)
|
||||
|
||||
// TopicsProvider
|
||||
type TopicsProvider interface {
|
||||
Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error)
|
||||
Unsubscribe(topic []byte, subscriber interface{}) error
|
||||
Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error
|
||||
Retain(msg *packets.PublishPacket) error
|
||||
Retained(topic []byte, msgs *[]*packets.PublishPacket) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
func Register(name string, provider TopicsProvider) {
|
||||
if provider == nil {
|
||||
panic("topics: Register provide is nil")
|
||||
}
|
||||
|
||||
if _, dup := providers[name]; dup {
|
||||
panic("topics: Register called twice for provider " + name)
|
||||
}
|
||||
|
||||
providers[name] = provider
|
||||
}
|
||||
|
||||
func Unregister(name string) {
|
||||
delete(providers, name)
|
||||
}
|
||||
|
||||
type Manager struct {
|
||||
p TopicsProvider
|
||||
}
|
||||
|
||||
func NewManager(providerName string) (*Manager, error) {
|
||||
p, ok := providers[providerName]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("session: unknown provider %q", providerName)
|
||||
}
|
||||
|
||||
return &Manager{p: p}, nil
|
||||
}
|
||||
|
||||
func (this *Manager) Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error) {
|
||||
return this.p.Subscribe(topic, qos, subscriber)
|
||||
}
|
||||
|
||||
func (this *Manager) Unsubscribe(topic []byte, subscriber interface{}) error {
|
||||
return this.p.Unsubscribe(topic, subscriber)
|
||||
}
|
||||
|
||||
func (this *Manager) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error {
|
||||
return this.p.Subscribers(topic, qos, subs, qoss)
|
||||
}
|
||||
|
||||
func (this *Manager) Retain(msg *packets.PublishPacket) error {
|
||||
return this.p.Retain(msg)
|
||||
}
|
||||
|
||||
func (this *Manager) Retained(topic []byte, msgs *[]*packets.PublishPacket) error {
|
||||
return this.p.Retained(topic, msgs)
|
||||
}
|
||||
|
||||
func (this *Manager) Close() error {
|
||||
return this.p.Close()
|
||||
}
|
||||
@@ -1,58 +0,0 @@
|
||||
package broker
|
||||
|
||||
import "sync"
|
||||
|
||||
const (
|
||||
MaxUser = 1024 * 1024
|
||||
MessagePoolNum = 1024
|
||||
MessagePoolUser = MaxUser / MessagePoolNum
|
||||
MessagePoolMessageNum = MaxUser / MessagePoolNum * 4
|
||||
)
|
||||
|
||||
type Message struct {
|
||||
client *client
|
||||
msg []byte
|
||||
}
|
||||
|
||||
var (
|
||||
MSGPool []MessagePool
|
||||
)
|
||||
|
||||
type MessagePool struct {
|
||||
l sync.Mutex
|
||||
maxuser int
|
||||
user int
|
||||
queue chan *Message
|
||||
}
|
||||
|
||||
func InitMessagePool() {
|
||||
MSGPool = make([]MessagePool, (MessagePoolNum + 2))
|
||||
for i := 0; i < (MessagePoolNum + 2); i++ {
|
||||
MSGPool[i].Init(MessagePoolUser, MessagePoolMessageNum)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *MessagePool) Init(num int, maxusernum int) {
|
||||
p.maxuser = maxusernum
|
||||
p.queue = make(chan *Message, num)
|
||||
}
|
||||
|
||||
func (p *MessagePool) GetPool() *MessagePool {
|
||||
p.l.Lock()
|
||||
if p.user+1 < p.maxuser {
|
||||
p.user += 1
|
||||
p.l.Unlock()
|
||||
return p
|
||||
} else {
|
||||
p.l.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (p *MessagePool) Reduce() {
|
||||
p.l.Lock()
|
||||
p.user -= 1
|
||||
p.l.Unlock()
|
||||
|
||||
}
|
||||
236
broker/packet.go
236
broker/packet.go
@@ -1,236 +0,0 @@
|
||||
package broker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"hmq/lib/message"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
log "github.com/cihub/seelog"
|
||||
)
|
||||
|
||||
func checkError(desc string, err error) {
|
||||
if err != nil {
|
||||
log.Error(desc, " : ", err)
|
||||
}
|
||||
}
|
||||
|
||||
func ReadPacket(conn net.Conn) ([]byte, error) {
|
||||
if conn == nil {
|
||||
return nil, errors.New("conn is null")
|
||||
}
|
||||
// conn.SetReadDeadline(t)
|
||||
var buf []byte
|
||||
// read fix header
|
||||
b := make([]byte, 1)
|
||||
_, err := io.ReadFull(conn, b)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf = append(buf, b...)
|
||||
// read rem msg length
|
||||
rembuf, remlen := decodeLength(conn)
|
||||
buf = append(buf, rembuf...)
|
||||
// read rem msg
|
||||
packetBytes := make([]byte, remlen)
|
||||
_, err = io.ReadFull(conn, packetBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
buf = append(buf, packetBytes...)
|
||||
// log.Info("len buf: ", len(buf))
|
||||
return buf, nil
|
||||
}
|
||||
|
||||
func decodeLength(r io.Reader) ([]byte, int) {
|
||||
var rLength uint32
|
||||
var multiplier uint32
|
||||
var buf []byte
|
||||
b := make([]byte, 1)
|
||||
for {
|
||||
io.ReadFull(r, b)
|
||||
digit := b[0]
|
||||
buf = append(buf, b[0])
|
||||
rLength |= uint32(digit&127) << multiplier
|
||||
if (digit & 128) == 0 {
|
||||
break
|
||||
}
|
||||
multiplier += 7
|
||||
|
||||
}
|
||||
return buf, int(rLength)
|
||||
}
|
||||
|
||||
func DecodeMessage(buf []byte) (message.Message, error) {
|
||||
msgType := uint8(buf[0] & 0xF0 >> 4)
|
||||
switch msgType {
|
||||
case CONNECT:
|
||||
return DecodeConnectMessage(buf)
|
||||
case CONNACK:
|
||||
return DecodeConnackMessage(buf)
|
||||
case PUBLISH:
|
||||
return DecodePublishMessage(buf)
|
||||
case PUBACK:
|
||||
return DecodePubackMessage(buf)
|
||||
case PUBCOMP:
|
||||
return DecodePubcompMessage(buf)
|
||||
case PUBREC:
|
||||
return DecodePubrecMessage(buf)
|
||||
case PUBREL:
|
||||
return DecodePubrelMessage(buf)
|
||||
case SUBSCRIBE:
|
||||
return DecodeSubscribeMessage(buf)
|
||||
case SUBACK:
|
||||
return DecodeSubackMessage(buf)
|
||||
case UNSUBSCRIBE:
|
||||
return DecodeUnsubscribeMessage(buf)
|
||||
case UNSUBACK:
|
||||
return DecodeUnsubackMessage(buf)
|
||||
case PINGREQ:
|
||||
return DecodePingreqMessage(buf)
|
||||
case PINGRESP:
|
||||
return DecodePingrespMessage(buf)
|
||||
case DISCONNECT:
|
||||
return DecodeDisconnectMessage(buf)
|
||||
default:
|
||||
return nil, errors.New("error message type")
|
||||
}
|
||||
}
|
||||
|
||||
func DecodeConnectMessage(buf []byte) (*message.ConnectMessage, error) {
|
||||
connMsg := message.NewConnectMessage()
|
||||
_, err := connMsg.Decode(buf)
|
||||
if err != nil {
|
||||
if !message.ValidConnackError(err) {
|
||||
return nil, errors.New("Connect message format error, " + err.Error())
|
||||
}
|
||||
return nil, errors.New("Deode connect message error, " + err.Error())
|
||||
}
|
||||
return connMsg, nil
|
||||
}
|
||||
|
||||
func DecodeConnackMessage(buf []byte) (*message.ConnackMessage, error) {
|
||||
msg := message.NewConnackMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Connack message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodePublishMessage(buf []byte) (*message.PublishMessage, error) {
|
||||
msg := message.NewPublishMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Publish message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodePubackMessage(buf []byte) (*message.PubackMessage, error) {
|
||||
msg := message.NewPubackMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Puback message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodePubrecMessage(buf []byte) (*message.PubrecMessage, error) {
|
||||
msg := message.NewPubrecMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Pubrec message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodePubrelMessage(buf []byte) (*message.PubrelMessage, error) {
|
||||
msg := message.NewPubrelMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Pubrel message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodePubcompMessage(buf []byte) (*message.PubcompMessage, error) {
|
||||
msg := message.NewPubcompMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Pubcomp message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodeSubscribeMessage(buf []byte) (*message.SubscribeMessage, error) {
|
||||
msg := message.NewSubscribeMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Subscribe message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodeSubackMessage(buf []byte) (*message.SubackMessage, error) {
|
||||
msg := message.NewSubackMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Suback message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodeUnsubscribeMessage(buf []byte) (*message.UnsubscribeMessage, error) {
|
||||
msg := message.NewUnsubscribeMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Unsubscribe message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodeUnsubackMessage(buf []byte) (*message.UnsubackMessage, error) {
|
||||
msg := message.NewUnsubackMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Unsuback message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodePingreqMessage(buf []byte) (*message.PingreqMessage, error) {
|
||||
msg := message.NewPingreqMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Pingreq message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodePingrespMessage(buf []byte) (*message.PingrespMessage, error) {
|
||||
msg := message.NewPingrespMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Pingresp message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func DecodeDisconnectMessage(buf []byte) (*message.DisconnectMessage, error) {
|
||||
msg := message.NewDisconnectMessage()
|
||||
_, err := msg.Decode(buf)
|
||||
if err != nil {
|
||||
return nil, errors.New("Decode Disconnect message error, " + err.Error())
|
||||
}
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func EncodeMessage(msg message.Message) ([]byte, error) {
|
||||
buf := make([]byte, msg.Len())
|
||||
_, err := msg.Encode(buf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return buf, nil
|
||||
}
|
||||
120
broker/retain.go
120
broker/retain.go
@@ -1,120 +0,0 @@
|
||||
package broker
|
||||
|
||||
import (
|
||||
"sync"
|
||||
)
|
||||
|
||||
type RetainList struct {
|
||||
sync.RWMutex
|
||||
root *rlevel
|
||||
}
|
||||
type rlevel struct {
|
||||
nodes map[string]*rnode
|
||||
}
|
||||
type rnode struct {
|
||||
next *rlevel
|
||||
msg []byte
|
||||
}
|
||||
type RetainResult struct {
|
||||
msg [][]byte
|
||||
}
|
||||
|
||||
func newRNode() *rnode {
|
||||
return &rnode{msg: make([]byte, 0, 4)}
|
||||
}
|
||||
|
||||
func newRLevel() *rlevel {
|
||||
return &rlevel{nodes: make(map[string]*rnode)}
|
||||
}
|
||||
|
||||
func NewRetainList() *RetainList {
|
||||
return &RetainList{root: newRLevel()}
|
||||
}
|
||||
|
||||
func (r *RetainList) Insert(topic, buf []byte) error {
|
||||
|
||||
tokens, err := PublishTopicCheckAndSpilt(topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// log.Info("insert tokens:", tokens)
|
||||
r.Lock()
|
||||
|
||||
l := r.root
|
||||
var n *rnode
|
||||
for _, t := range tokens {
|
||||
n = l.nodes[t]
|
||||
if n == nil {
|
||||
n = newRNode()
|
||||
l.nodes[t] = n
|
||||
}
|
||||
if n.next == nil {
|
||||
n.next = newRLevel()
|
||||
}
|
||||
l = n.next
|
||||
}
|
||||
n.msg = buf
|
||||
r.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *RetainList) Match(topic []byte) [][]byte {
|
||||
|
||||
tokens, err := SubscribeTopicCheckAndSpilt(topic)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
results := &RetainResult{}
|
||||
|
||||
r.Lock()
|
||||
l := r.root
|
||||
matchRLevel(l, tokens, results)
|
||||
r.Unlock()
|
||||
// log.Info("results: ", results)
|
||||
return results.msg
|
||||
|
||||
}
|
||||
func matchRLevel(l *rlevel, toks []string, results *RetainResult) {
|
||||
var n *rnode
|
||||
for i, t := range toks {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
// log.Info("l info :", l.nodes)
|
||||
if t == "#" {
|
||||
for _, n := range l.nodes {
|
||||
n.GetAll(results)
|
||||
}
|
||||
}
|
||||
if t == "+" {
|
||||
for _, n := range l.nodes {
|
||||
if len(t[i+1:]) == 0 {
|
||||
results.msg = append(results.msg, n.msg)
|
||||
} else {
|
||||
matchRLevel(n.next, toks[i+1:], results)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
n = l.nodes[t]
|
||||
if n != nil {
|
||||
l = n.next
|
||||
} else {
|
||||
l = nil
|
||||
}
|
||||
}
|
||||
if n != nil {
|
||||
results.msg = append(results.msg, n.msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *rnode) GetAll(results *RetainResult) {
|
||||
// log.Info("node 's message: ", string(r.msg))
|
||||
if r.msg != nil && string(r.msg) != "" {
|
||||
results.msg = append(results.msg, r.msg)
|
||||
}
|
||||
l := r.next
|
||||
for _, n := range l.nodes {
|
||||
n.GetAll(results)
|
||||
}
|
||||
}
|
||||
55
broker/sesson.go
Normal file
55
broker/sesson.go
Normal file
@@ -0,0 +1,55 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package broker
|
||||
|
||||
import "github.com/eclipse/paho.mqtt.golang/packets"
|
||||
|
||||
func (b *Broker) getSession(cli *client, req *packets.ConnectPacket, resp *packets.ConnackPacket) error {
|
||||
// If CleanSession is set to 0, the server MUST resume communications with the
|
||||
// client based on state from the current session, as identified by the client
|
||||
// identifier. If there is no session associated with the client identifier the
|
||||
// server must create a new session.
|
||||
//
|
||||
// If CleanSession is set to 1, the client and server must discard any previous
|
||||
// session and start a new one. b session lasts as long as the network c
|
||||
// onnection. State data associated with b session must not be reused in any
|
||||
// subsequent session.
|
||||
|
||||
var err error
|
||||
|
||||
// Check to see if the client supplied an ID, if not, generate one and set
|
||||
// clean session.
|
||||
|
||||
if len(req.ClientIdentifier) == 0 {
|
||||
req.CleanSession = true
|
||||
}
|
||||
|
||||
cid := req.ClientIdentifier
|
||||
|
||||
// If CleanSession is NOT set, check the session store for existing session.
|
||||
// If found, return it.
|
||||
if !req.CleanSession {
|
||||
if cli.session, err = b.sessionMgr.Get(cid); err == nil {
|
||||
resp.SessionPresent = true
|
||||
|
||||
if err := cli.session.Update(req); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If CleanSession, or no existing session found, then create a new one
|
||||
if cli.session == nil {
|
||||
if cli.session, err = b.sessionMgr.New(cid); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
resp.SessionPresent = false
|
||||
|
||||
if err := cli.session.Init(req); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -1,318 +0,0 @@
|
||||
package broker
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
log "github.com/cihub/seelog"
|
||||
)
|
||||
|
||||
// A result structure better optimized for queue subs.
|
||||
type SublistResult struct {
|
||||
psubs []*subscription
|
||||
qsubs []*subscription // don't make this a map, too expensive to iterate
|
||||
}
|
||||
|
||||
// A Sublist stores and efficiently retrieves subscriptions.
|
||||
type Sublist struct {
|
||||
sync.RWMutex
|
||||
cache map[string]*SublistResult
|
||||
root *level
|
||||
}
|
||||
|
||||
// A node contains subscriptions and a pointer to the next level.
|
||||
type node struct {
|
||||
next *level
|
||||
psubs []*subscription
|
||||
qsubs []*subscription
|
||||
}
|
||||
|
||||
// A level represents a group of nodes and special pointers to
|
||||
// wildcard nodes.
|
||||
type level struct {
|
||||
nodes map[string]*node
|
||||
}
|
||||
|
||||
// Create a new default node.
|
||||
func newNode() *node {
|
||||
return &node{psubs: make([]*subscription, 0, 4), qsubs: make([]*subscription, 0, 4)}
|
||||
}
|
||||
|
||||
// Create a new default level. We use FNV1A as the hash
|
||||
// algortihm for the tokens, which should be short.
|
||||
func newLevel() *level {
|
||||
return &level{nodes: make(map[string]*node)}
|
||||
}
|
||||
|
||||
// New will create a default sublist
|
||||
func NewSublist() *Sublist {
|
||||
return &Sublist{root: newLevel(), cache: make(map[string]*SublistResult)}
|
||||
}
|
||||
|
||||
// Insert adds a subscription into the sublist
|
||||
func (s *Sublist) Insert(sub *subscription) error {
|
||||
|
||||
tokens, err := SubscribeTopicCheckAndSpilt(sub.topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Lock()
|
||||
|
||||
l := s.root
|
||||
var n *node
|
||||
for _, t := range tokens {
|
||||
n = l.nodes[t]
|
||||
if n == nil {
|
||||
n = newNode()
|
||||
l.nodes[t] = n
|
||||
}
|
||||
if n.next == nil {
|
||||
n.next = newLevel()
|
||||
}
|
||||
l = n.next
|
||||
}
|
||||
if sub.queue {
|
||||
//check qsub is already exist
|
||||
for i := range n.qsubs {
|
||||
if equal(n.qsubs[i], sub) {
|
||||
n.qsubs[i] = sub
|
||||
return nil
|
||||
}
|
||||
}
|
||||
n.qsubs = append(n.qsubs, sub)
|
||||
} else {
|
||||
//check psub is already exist
|
||||
for i := range n.psubs {
|
||||
if equal(n.psubs[i], sub) {
|
||||
n.psubs[i] = sub
|
||||
return nil
|
||||
}
|
||||
}
|
||||
n.psubs = append(n.psubs, sub)
|
||||
}
|
||||
|
||||
topic := string(sub.topic)
|
||||
s.addToCache(topic, sub)
|
||||
s.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Sublist) addToCache(topic string, sub *subscription) {
|
||||
for k, r := range s.cache {
|
||||
if matchLiteral(k, topic) {
|
||||
// Copy since others may have a reference.
|
||||
nr := copyResult(r)
|
||||
if sub.queue == false {
|
||||
nr.psubs = append(nr.psubs, sub)
|
||||
} else {
|
||||
nr.qsubs = append(nr.qsubs, sub)
|
||||
}
|
||||
s.cache[k] = nr
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Sublist) removeFromCache(topic string, sub *subscription) {
|
||||
for k := range s.cache {
|
||||
if !matchLiteral(k, topic) {
|
||||
continue
|
||||
}
|
||||
// Since someone else may be referecing, can't modify the list
|
||||
// safely, just let it re-populate.
|
||||
delete(s.cache, k)
|
||||
}
|
||||
}
|
||||
|
||||
func matchLiteral(literal, topic string) bool {
|
||||
tok, _ := SubscribeTopicCheckAndSpilt([]byte(topic))
|
||||
li, _ := PublishTopicCheckAndSpilt([]byte(literal))
|
||||
|
||||
for i := 0; i < len(tok); i++ {
|
||||
b := tok[i]
|
||||
switch b {
|
||||
case "+":
|
||||
|
||||
case "#":
|
||||
return true
|
||||
default:
|
||||
if b != li[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Deep copy
|
||||
func copyResult(r *SublistResult) *SublistResult {
|
||||
nr := &SublistResult{}
|
||||
nr.psubs = append([]*subscription(nil), r.psubs...)
|
||||
nr.qsubs = append([]*subscription(nil), r.qsubs...)
|
||||
return nr
|
||||
}
|
||||
|
||||
func (s *Sublist) Remove(sub *subscription) error {
|
||||
tokens, err := SubscribeTopicCheckAndSpilt(sub.topic)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Lock()
|
||||
defer s.Unlock()
|
||||
|
||||
l := s.root
|
||||
var n *node
|
||||
|
||||
for _, t := range tokens {
|
||||
if l == nil {
|
||||
return errors.New("No Matches subscription Found")
|
||||
}
|
||||
n = l.nodes[t]
|
||||
if n != nil {
|
||||
l = n.next
|
||||
} else {
|
||||
l = nil
|
||||
}
|
||||
}
|
||||
if !s.removeFromNode(n, sub) {
|
||||
return errors.New("No Matches subscription Found")
|
||||
}
|
||||
topic := string(sub.topic)
|
||||
s.removeFromCache(topic, sub)
|
||||
return nil
|
||||
|
||||
}
|
||||
|
||||
func (s *Sublist) removeFromNode(n *node, sub *subscription) (found bool) {
|
||||
if n == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if sub.queue {
|
||||
n.qsubs, found = removeSubFromList(sub, n.qsubs)
|
||||
return found
|
||||
} else {
|
||||
n.psubs, found = removeSubFromList(sub, n.psubs)
|
||||
return found
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *Sublist) Match(topic string) *SublistResult {
|
||||
s.RLock()
|
||||
rc, ok := s.cache[topic]
|
||||
s.RUnlock()
|
||||
|
||||
if ok {
|
||||
return rc
|
||||
}
|
||||
|
||||
tokens, err := PublishTopicCheckAndSpilt([]byte(topic))
|
||||
if err != nil {
|
||||
log.Error("\tserver/sublist.go: ", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
result := &SublistResult{}
|
||||
|
||||
s.Lock()
|
||||
l := s.root
|
||||
if len(tokens) > 0 {
|
||||
if tokens[0] == "/" {
|
||||
if _, exist := l.nodes["#"]; exist {
|
||||
addNodeToResults(l.nodes["#"], result)
|
||||
}
|
||||
if _, exist := l.nodes["+"]; exist {
|
||||
matchLevel(l.nodes["/"].next, tokens[1:], result)
|
||||
}
|
||||
if _, exist := l.nodes["/"]; exist {
|
||||
matchLevel(l.nodes["/"].next, tokens[1:], result)
|
||||
}
|
||||
} else {
|
||||
matchLevel(s.root, tokens, result)
|
||||
}
|
||||
}
|
||||
s.cache[topic] = result
|
||||
if len(s.cache) > 1024 {
|
||||
for k := range s.cache {
|
||||
delete(s.cache, k)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
s.Unlock()
|
||||
// log.Info("SublistResult: ", result)
|
||||
return result
|
||||
}
|
||||
|
||||
func matchLevel(l *level, toks []string, results *SublistResult) {
|
||||
var swc, n *node
|
||||
exist := false
|
||||
for i, t := range toks {
|
||||
if l == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if _, exist = l.nodes["#"]; exist {
|
||||
addNodeToResults(l.nodes["#"], results)
|
||||
}
|
||||
if t != "/" {
|
||||
if swc, exist = l.nodes["+"]; exist {
|
||||
matchLevel(l.nodes["+"].next, toks[i+1:], results)
|
||||
}
|
||||
} else {
|
||||
if _, exist = l.nodes["+"]; exist {
|
||||
addNodeToResults(l.nodes["+"], results)
|
||||
}
|
||||
}
|
||||
|
||||
n = l.nodes[t]
|
||||
if n != nil {
|
||||
l = n.next
|
||||
} else {
|
||||
l = nil
|
||||
}
|
||||
}
|
||||
if n != nil {
|
||||
addNodeToResults(n, results)
|
||||
}
|
||||
if swc != nil {
|
||||
addNodeToResults(n, results)
|
||||
}
|
||||
}
|
||||
|
||||
// This will add in a node's results to the total results.
|
||||
func addNodeToResults(n *node, results *SublistResult) {
|
||||
results.psubs = append(results.psubs, n.psubs...)
|
||||
results.qsubs = append(results.qsubs, n.qsubs...)
|
||||
}
|
||||
|
||||
func removeSubFromList(sub *subscription, sl []*subscription) ([]*subscription, bool) {
|
||||
for i := 0; i < len(sl); i++ {
|
||||
if sl[i] == sub {
|
||||
last := len(sl) - 1
|
||||
sl[i] = sl[last]
|
||||
sl[last] = nil
|
||||
sl = sl[:last]
|
||||
// log.Info("removeSubFromList success")
|
||||
return shrinkAsNeeded(sl), true
|
||||
}
|
||||
}
|
||||
return sl, false
|
||||
}
|
||||
|
||||
// Checks if we need to do a resize. This is for very large growth then
|
||||
// subsequent return to a more normal size from unsubscribe.
|
||||
func shrinkAsNeeded(sl []*subscription) []*subscription {
|
||||
lsl := len(sl)
|
||||
csl := cap(sl)
|
||||
// Don't bother if list not too big
|
||||
if csl <= 8 {
|
||||
return sl
|
||||
}
|
||||
pFree := float32(csl-lsl) / float32(csl)
|
||||
if pFree > 0.50 {
|
||||
return append([]*subscription(nil), sl...)
|
||||
}
|
||||
return sl
|
||||
}
|
||||
24
broker/usage.go
Normal file
24
broker/usage.go
Normal file
@@ -0,0 +1,24 @@
|
||||
package broker
|
||||
|
||||
var usageStr = `
|
||||
Usage: hmq [options]
|
||||
|
||||
Broker Options:
|
||||
-w, --worker <number> Worker num to process message, perfer (client num)/10. (default 1024)
|
||||
-p, --port <port> Use port for clients (default: 1883)
|
||||
--host <host> Network host to listen on. (default "0.0.0.0")
|
||||
-ws, --wsport <port> Use port for websocket monitoring
|
||||
-wsp,--wspath <path> Use path for websocket monitoring
|
||||
-c, --config <file> Configuration file
|
||||
|
||||
Logging Options:
|
||||
-d, --debug <bool> Enable debugging output (default false)
|
||||
-D Debug and trace
|
||||
|
||||
Cluster Options:
|
||||
-r, --router <rurl> Router who maintenance cluster info
|
||||
-cp, --clusterport <cluster-port> Cluster listen port for others
|
||||
|
||||
Common Options:
|
||||
-h, --help Show this message
|
||||
`
|
||||
@@ -1,37 +0,0 @@
|
||||
package broker
|
||||
|
||||
type Worker struct {
|
||||
WorkerPool chan chan *Message
|
||||
MsgChannel chan *Message
|
||||
quit chan bool
|
||||
}
|
||||
|
||||
func NewWorker(workerPool chan chan *Message) Worker {
|
||||
return Worker{
|
||||
WorkerPool: workerPool,
|
||||
MsgChannel: make(chan *Message),
|
||||
quit: make(chan bool)}
|
||||
}
|
||||
|
||||
func (w Worker) Start() {
|
||||
go func() {
|
||||
for {
|
||||
// register the current worker into the worker queue.
|
||||
w.WorkerPool <- w.MsgChannel
|
||||
select {
|
||||
case msg := <-w.MsgChannel:
|
||||
// we have received a work request.
|
||||
ProcessMessage(msg)
|
||||
case <-w.quit:
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Stop signals the worker to stop listening for work requests.
|
||||
func (w Worker) Stop() {
|
||||
go func() {
|
||||
w.quit <- true
|
||||
}()
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
## pub 1 , sub 2, pubsub 3
|
||||
## %c is clientid , %s is username
|
||||
##auth type value pub/sub topic
|
||||
allow ip 127.0.0.1 2 $SYS/#
|
||||
allow clientid 0001 3 #
|
||||
deny username admin 3 #
|
||||
allow username joy 3 /test,hello/world
|
||||
allow clientid * 1 toCloud/%c
|
||||
allow username * 1 toCloud/%u
|
||||
allow clientid * 2 toDevice/%c
|
||||
allow username * 2 toDevice/%u
|
||||
deny clientid * 3 #
|
||||
@@ -1,11 +1,13 @@
|
||||
{
|
||||
"workerNum": 4096,
|
||||
"port": "1883",
|
||||
"host": "0.0.0.0",
|
||||
"debug": true,
|
||||
"cluster": {
|
||||
"host": "0.0.0.0",
|
||||
"port": "1993",
|
||||
"routes": []
|
||||
"port": "1993"
|
||||
},
|
||||
"router": "127.0.0.1:9888",
|
||||
"tlsPort": "8883",
|
||||
"tlsHost": "0.0.0.0",
|
||||
"wsPort": "1888",
|
||||
@@ -17,6 +19,8 @@
|
||||
"certFile": "ssl/server/cert.pem",
|
||||
"keyFile": "ssl/server/key.pem"
|
||||
},
|
||||
"acl": true,
|
||||
"aclConf": "conf/acl.conf"
|
||||
}
|
||||
"plugins": {
|
||||
"auth": "authhttp",
|
||||
"bridge": "kafka"
|
||||
}
|
||||
}
|
||||
38
deploy/config.yaml
Normal file
38
deploy/config.yaml
Normal file
@@ -0,0 +1,38 @@
|
||||
apiVersion: v1
|
||||
kind: ConfigMap
|
||||
metadata:
|
||||
name: mqtt-broker
|
||||
data:
|
||||
hmq.config: |
|
||||
{
|
||||
"workerNum": 4096,
|
||||
"port": "1883",
|
||||
"host": "0.0.0.0",
|
||||
"plugins": ["authhttp","kafka"]
|
||||
}
|
||||
|
||||
kafka.json: |
|
||||
{
|
||||
"addr": [
|
||||
"127.0.0.1:9090"
|
||||
],
|
||||
"onConnect": "onConnect",
|
||||
"onPublish": "onPublish",
|
||||
"onSubscribe": "onSubscribe",
|
||||
"onDisconnect": "onDisconnect",
|
||||
"onUnsubscribe": "onUnsubscribe",
|
||||
"regexpMap": [
|
||||
{
|
||||
"^/(.+)/(.+)/upload/(.*)$": "upload"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
authhttp.json: |
|
||||
{
|
||||
"auth": "http://127.0.0.1:9090/mqtt/auth",
|
||||
"acl": "http://127.0.0.1:9090/mqtt/acl",
|
||||
"super": "http://127.0.0.1:9090/mqtt/superuser"
|
||||
}
|
||||
|
||||
|
||||
44
deploy/deploy.yaml
Normal file
44
deploy/deploy.yaml
Normal file
@@ -0,0 +1,44 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: mqtt-broker
|
||||
spec:
|
||||
selector:
|
||||
matchLabels:
|
||||
app: mqtt-broker
|
||||
replicas: 1
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: mqtt-broker
|
||||
spec:
|
||||
containers:
|
||||
- name: mqtt-broker
|
||||
image: uhub.service.ucloud.cn/uiot_core_hub/hmq:v0.1.0
|
||||
ports:
|
||||
- containerPort: 1883
|
||||
- containerPort: 8080
|
||||
volumeMounts:
|
||||
- name: mqtt-broker1
|
||||
mountPath: /conf
|
||||
subPath: hmq.config
|
||||
- name: mqtt-broker1
|
||||
mountPath: /plugins/kafka/kafka.json
|
||||
subPath: kafka.json
|
||||
- name: mqtt-broker1
|
||||
mountPath: /plugins/authttp/http.json
|
||||
subPath: kafka.json
|
||||
volumes:
|
||||
- name: mqtt-broker1
|
||||
configMap:
|
||||
name: mqtt-broker1
|
||||
items:
|
||||
- key: hmq.config
|
||||
path: hmq.config
|
||||
items:
|
||||
- key: http.json
|
||||
path: http.json
|
||||
items:
|
||||
- key: kafka.json
|
||||
path: kafka.json
|
||||
|
||||
13
deploy/svc.yaml
Normal file
13
deploy/svc.yaml
Normal file
@@ -0,0 +1,13 @@
|
||||
kind: Service
|
||||
apiVersion: v1
|
||||
metadata:
|
||||
name: mqtt-broker
|
||||
spec:
|
||||
selector:
|
||||
app: mqtt-broker
|
||||
ports:
|
||||
- protocol: TCP
|
||||
port: 8080
|
||||
targetPort: 8080
|
||||
type: ClusterIP
|
||||
sessionAffinity: ClientIP
|
||||
30
go.mod
Normal file
30
go.mod
Normal file
@@ -0,0 +1,30 @@
|
||||
module github.com/fhmq/hmq
|
||||
|
||||
go 1.12
|
||||
|
||||
require (
|
||||
github.com/Shopify/sarama v1.23.0
|
||||
github.com/StackExchange/wmi v0.0.0-20181212234831-e0a55b97c705 // indirect
|
||||
github.com/bitly/go-simplejson v0.5.0
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 // indirect
|
||||
github.com/eclipse/paho.mqtt.golang v1.2.0
|
||||
github.com/gin-gonic/gin v1.4.0
|
||||
github.com/go-ole/go-ole v1.2.4 // indirect
|
||||
github.com/golang/protobuf v1.3.2 // indirect
|
||||
github.com/kr/pretty v0.1.0 // indirect
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible
|
||||
github.com/pkg/errors v0.8.1 // indirect
|
||||
github.com/satori/go.uuid v1.2.0
|
||||
github.com/segmentio/fasthash v0.0.0-20180216231524-a72b379d632e
|
||||
github.com/shirou/gopsutil v2.18.12+incompatible
|
||||
github.com/stretchr/testify v1.3.0
|
||||
github.com/tidwall/gjson v1.3.0
|
||||
go.uber.org/atomic v1.4.0 // indirect
|
||||
go.uber.org/multierr v1.1.0 // indirect
|
||||
go.uber.org/zap v1.10.0
|
||||
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 // indirect
|
||||
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80
|
||||
golang.org/x/sys v0.0.0-20190730183949-1393eb018365 // indirect
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 // indirect
|
||||
gopkg.in/jcmturner/goidentity.v3 v3.0.0 // indirect
|
||||
)
|
||||
125
go.sum
Normal file
125
go.sum
Normal file
@@ -0,0 +1,125 @@
|
||||
github.com/DataDog/zstd v1.3.6-0.20190409195224-796139022798 h1:2T/jmrHeTezcCM58lvEQXs0UpQJCo5SoGAcg+mbSTIg=
|
||||
github.com/DataDog/zstd v1.3.6-0.20190409195224-796139022798/go.mod h1:1jcaCB/ufaK+sKp1NBhlGmpz41jOoPQ35bpF36t7BBo=
|
||||
github.com/Shopify/sarama v1.23.0 h1:slvlbm7bxyp7sKQbUwha5BQdZTqurhRoI+zbKorVigQ=
|
||||
github.com/Shopify/sarama v1.23.0/go.mod h1:XLH1GYJnLVE0XCr6KdJGVJRTwY30moWNJ4sERjXX6fs=
|
||||
github.com/Shopify/toxiproxy v2.1.4+incompatible h1:TKdv8HiTLgE5wdJuEML90aBgNWsokNbMijUGhmcoBJc=
|
||||
github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI=
|
||||
github.com/StackExchange/wmi v0.0.0-20181212234831-e0a55b97c705 h1:UUppSQnhf4Yc6xGxSkoQpPhb7RVzuv5Nb1mwJ5VId9s=
|
||||
github.com/StackExchange/wmi v0.0.0-20181212234831-e0a55b97c705/go.mod h1:3eOhrUMpNV+6aFIbp5/iudMxNCF27Vw2OZgy4xEx0Fg=
|
||||
github.com/bitly/go-simplejson v0.5.0 h1:6IH+V8/tVMab511d5bn4M7EwGXZf9Hj6i2xSwkNEM+Y=
|
||||
github.com/bitly/go-simplejson v0.5.0/go.mod h1:cXHtHw4XUPsvGaxgjIAn8PhEWG9NfngEKAMDJEczWVA=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY=
|
||||
github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4=
|
||||
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/eapache/go-resiliency v1.1.0 h1:1NtRmCAqadE2FN4ZcN6g90TP3uk8cg9rn9eNK2197aU=
|
||||
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
|
||||
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 h1:YEetp8/yCZMuEPMUDHG0CW/brkkEp8mzqk2+ODEitlw=
|
||||
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=
|
||||
github.com/eapache/queue v1.1.0 h1:YOEu7KNc61ntiQlcEeUIoDTJ2o8mQznoNvUhiigpIqc=
|
||||
github.com/eapache/queue v1.1.0/go.mod h1:6eCeP0CKFpHLu8blIFXhExK/dRa7WDZfr6jVFPTqq+I=
|
||||
github.com/eclipse/paho.mqtt.golang v1.2.0 h1:1F8mhG9+aO5/xpdtFkW4SxOJB67ukuDC3t2y2qayIX0=
|
||||
github.com/eclipse/paho.mqtt.golang v1.2.0/go.mod h1:H9keYFcgq3Qr5OUJm/JZI/i6U7joQ8SYLhZwfeOo6Ts=
|
||||
github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3 h1:t8FVkw33L+wilf2QiWkw0UV77qRpcH/JHPKGpKa2E8g=
|
||||
github.com/gin-contrib/sse v0.0.0-20190301062529-5545eab6dad3/go.mod h1:VJ0WA2NBN22VlZ2dKZQPAPnyWw5XTlK1KymzLKsr59s=
|
||||
github.com/gin-gonic/gin v1.4.0 h1:3tMoCCfM7ppqsR0ptz/wi1impNpT7/9wQtMZ8lr1mCQ=
|
||||
github.com/gin-gonic/gin v1.4.0/go.mod h1:OW2EZn3DO8Ln9oIKOvM++LBO+5UPHJJDH72/q/3rZdM=
|
||||
github.com/go-ole/go-ole v1.2.4 h1:nNBDSCOigTSiarFpYE9J/KtEA1IOW4CNeqT9TQDqCxI=
|
||||
github.com/go-ole/go-ole v1.2.4/go.mod h1:XCwSNxSkXRo4vlyPy93sltvi/qJq0jqQhjqQNIwKuxM=
|
||||
github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
|
||||
github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
|
||||
github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
|
||||
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
|
||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
|
||||
github.com/hashicorp/go-uuid v1.0.1 h1:fv1ep09latC32wFoVwnqcnKJGnMSdBanPczbHAYm1BE=
|
||||
github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
|
||||
github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03 h1:FUwcHNlEqkqLjLBdCp5PRlCFijNjvcYANOZXzCfXwCM=
|
||||
github.com/jcmturner/gofork v0.0.0-20190328161633-dc7c13fece03/go.mod h1:MK8+TM0La+2rjBD4jE12Kj1pCCxK7d2LK/UM3ncEo0o=
|
||||
github.com/json-iterator/go v1.1.6 h1:MrUvLMLTMxbqFJ9kzlvat/rYZqZnW3u4wkLzWTaFwKs=
|
||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/mattn/go-isatty v0.0.7 h1:UvyT9uN+3r7yLEYSlJsbQGdsaB/a0DlgWP3pql6iwOc=
|
||||
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg=
|
||||
github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
github.com/modern-go/reflect2 v1.0.1 h1:9f412s+6RmYXLWZSEzVVgPGK7C2PphHj5RJrvfx9AWI=
|
||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible h1:HRMgzkcYKYpi3C8ajMPV8OFXaaRUnok+kx1WdO15EQc=
|
||||
github.com/patrickmn/go-cache v2.1.0+incompatible/go.mod h1:3Qf8kWWT7OJRJbdiICTKqZju1ZixQ/KpMGzzAfe6+WQ=
|
||||
github.com/pierrec/lz4 v0.0.0-20190327172049-315a67e90e41 h1:GeinFsrjWz97fAxVUEd748aV0cYL+I6k44gFJTCVvpU=
|
||||
github.com/pierrec/lz4 v0.0.0-20190327172049-315a67e90e41/go.mod h1:3/3N9NVKO0jef7pBehbT1qWhCMrIgbYNnFAZCqQ5LRc=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/profile v1.2.1/go.mod h1:hJw3o1OdXxsrSjjVksARp5W95eeEaEfptyVZyv6JUPA=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a h1:9ZKAASQSHhDYGoxY8uLVpewe1GDZ2vu2Tr/vTdVAkFQ=
|
||||
github.com/rcrowley/go-metrics v0.0.0-20181016184325-3113b8401b8a/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
|
||||
github.com/satori/go.uuid v1.2.0 h1:0uYX9dsZ2yD7q2RtLRtPSdGDWzjeM3TbMJP9utgA0ww=
|
||||
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||
github.com/segmentio/fasthash v0.0.0-20180216231524-a72b379d632e h1:uO75wNGioszjmIzcY/tvdDYKRLVvzggtAmmJkn9j4GQ=
|
||||
github.com/segmentio/fasthash v0.0.0-20180216231524-a72b379d632e/go.mod h1:tm/wZFQ8e24NYaBGIlnO2WGCAi67re4HHuOm0sftE/M=
|
||||
github.com/shirou/gopsutil v2.18.12+incompatible h1:1eaJvGomDnH74/5cF4CTmTbLHAriGFsTZppLXDX93OM=
|
||||
github.com/shirou/gopsutil v2.18.12+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/testify v1.3.0 h1:TivCn/peBQ7UY8ooIcPgZFpTNSz0Q2U6UrFlUfqbe0Q=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/tidwall/gjson v1.3.0 h1:kfpsw1W3trbg4Xm6doUtqSl9+LhLB6qJ9PkltVAQZYs=
|
||||
github.com/tidwall/gjson v1.3.0/go.mod h1:P256ACg0Mn+j1RXIDXoss50DeIABTYK1PULOJHhxOls=
|
||||
github.com/tidwall/match v1.0.1 h1:PnKP62LPNxHKTwvHHZZzdOAOCtsJTjo6dZLCwpKm5xc=
|
||||
github.com/tidwall/match v1.0.1/go.mod h1:LujAq0jyVjBy028G1WhWfIzbpQfMO8bBZ6Tyb0+pL9E=
|
||||
github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4=
|
||||
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
|
||||
github.com/ugorji/go v1.1.4 h1:j4s+tAvLfL3bZyefP2SEWmhBzmuIlH/eqNuPdFPgngw=
|
||||
github.com/ugorji/go v1.1.4/go.mod h1:uQMGLiO92mf5W77hV/PUCpI3pbzQx3CRekS0kk+RGrc=
|
||||
github.com/xdg/scram v0.0.0-20180814205039-7eeb5667e42c/go.mod h1:lB8K/P019DLNhemzwFU4jHLhdvlE6uDZjXFejJXr49I=
|
||||
github.com/xdg/stringprep v1.0.0/go.mod h1:Jhud4/sHMO4oL310DaZAKk9ZaJ08SJfe+sJh0HrGL1Y=
|
||||
go.uber.org/atomic v1.4.0 h1:cxzIVoETapQEqDhQu3QfnvXAV4AlzcvUCxkVUFw3+EU=
|
||||
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/multierr v1.1.0 h1:HoEmRHQPVSqub6w2z2d2EOVs2fjyFRGyofhKuyDq0QI=
|
||||
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
||||
go.uber.org/zap v1.10.0 h1:ORx85nbTijNz8ljznvCMR1ZBIPKFn3jQrag10X2AsuM=
|
||||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190404164418-38d8ce5564a5/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
||||
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4 h1:HuIa8hRrWRSrqYzx1qI49NNxhdi2PrY7gxVSq1JjLDc=
|
||||
golang.org/x/crypto v0.0.0-20190701094942-4def268fd1a4/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c h1:uOCk1iQW6Vc18bnC13MfzScl+wdKBmM9Y9kU7Z83/lw=
|
||||
golang.org/x/net v0.0.0-20190503192946-f4e77d36d62c/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80 h1:Ao/3l156eZf2AW5wK8a7/smtodRU+gha3+BeqJ69lRk=
|
||||
golang.org/x/net v0.0.0-20190724013045-ca1201d0de80/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a h1:1BGLXjeY4akVXGgbC9HugT3Jv3hCI0z56oJR5vAMgBU=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190730183949-1393eb018365 h1:SaXEMXhWzMJThc05vu6uh61Q245r4KaWMrsTedk0FDc=
|
||||
golang.org/x/sys v0.0.0-20190730183949-1393eb018365/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1 h1:xoYuJVE7KT85PYWrN730RguIQO0ePzVRfFMXadIrXTM=
|
||||
gopkg.in/go-playground/assert.v1 v1.2.1/go.mod h1:9RXL0bg/zibRAgZUYszZSwO/z8Y/a8bDuhia5mkpMnE=
|
||||
gopkg.in/go-playground/validator.v8 v8.18.2 h1:lFB4DoMU6B626w8ny76MV7VX6W2VHct2GVOI3xgiMrQ=
|
||||
gopkg.in/go-playground/validator.v8 v8.18.2/go.mod h1:RX2a/7Ha8BgOhfk7j780h4/u/RRjR0eouCJSH80/M2Y=
|
||||
gopkg.in/jcmturner/aescts.v1 v1.0.1 h1:cVVZBK2b1zY26haWB4vbBiZrfFQnfbTVrE3xZq6hrEw=
|
||||
gopkg.in/jcmturner/aescts.v1 v1.0.1/go.mod h1:nsR8qBOg+OucoIW+WMhB3GspUQXq9XorLnQb9XtvcOo=
|
||||
gopkg.in/jcmturner/dnsutils.v1 v1.0.1 h1:cIuC1OLRGZrld+16ZJvvZxVJeKPsvd5eUIvxfoN5hSM=
|
||||
gopkg.in/jcmturner/dnsutils.v1 v1.0.1/go.mod h1:m3v+5svpVOhtFAP/wSz+yzh4Mc0Fg7eRhxkJMWSIz9Q=
|
||||
gopkg.in/jcmturner/goidentity.v3 v3.0.0 h1:1duIyWiTaYvVx3YX2CYtpJbUFd7/UuPYCfgXtQ3VTbI=
|
||||
gopkg.in/jcmturner/goidentity.v3 v3.0.0/go.mod h1:oG2kH0IvSYNIu80dVAyu/yoefjq1mNfM5bm88whjWx4=
|
||||
gopkg.in/jcmturner/gokrb5.v7 v7.2.3 h1:hHMV/yKPwMnJhPuPx7pH2Uw/3Qyf+thJYlisUc44010=
|
||||
gopkg.in/jcmturner/gokrb5.v7 v7.2.3/go.mod h1:l8VISx+WGYp+Fp7KRbsiUuXTTOnxIc3Tuvyavf11/WM=
|
||||
gopkg.in/jcmturner/rpc.v1 v1.1.0 h1:QHIUxTX1ISuAv9dD2wJ9HWQVuWDX/Zc0PfeC2tjc4rU=
|
||||
gopkg.in/jcmturner/rpc.v1 v1.1.0/go.mod h1:YIdkC4XfD6GXbzje11McwsDuOlZQSb9W4vfLvuNnlv8=
|
||||
gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
@@ -1,16 +0,0 @@
|
||||
# This is the official list of SurgeMQ authors for copyright purposes.
|
||||
|
||||
# If you are submitting a patch, please add your name or the name of the
|
||||
# organization which holds the copyright to this list in alphabetical order.
|
||||
|
||||
# Names should be added to this file as
|
||||
# Name <email address>
|
||||
# The email address is not required for organizations.
|
||||
# Please keep the list sorted.
|
||||
|
||||
|
||||
# Individual Persons
|
||||
|
||||
Jian Zhen <zhenjl@gmail.com>
|
||||
|
||||
# Organizations
|
||||
@@ -1,36 +0,0 @@
|
||||
# Contributing Guidelines
|
||||
|
||||
## Reporting Issues
|
||||
|
||||
Before creating a new Issue, please check first if a similar Issue [already exists](https://github.com/surgemq/message/issues?state=open) or was [recently closed](https://github.com/surgemq/message/issues?direction=desc&page=1&sort=updated&state=closed).
|
||||
|
||||
Please provide the following minimum information:
|
||||
* Your SurgeMQ version (or git SHA)
|
||||
* Your Go version (run `go version` in your console)
|
||||
* A detailed issue description
|
||||
* Error Log if present
|
||||
* If possible, a short example
|
||||
|
||||
|
||||
## Contributing Code
|
||||
|
||||
By contributing to this project, you share your code under the Apache License, Version 2.0, as specified in the LICENSE file.
|
||||
Don't forget to add yourself to the AUTHORS file.
|
||||
|
||||
### Pull Requests Checklist
|
||||
|
||||
Please check the following points before submitting your pull request:
|
||||
- [x] Code compiles correctly
|
||||
- [x] Created tests, if possible
|
||||
- [x] All tests pass
|
||||
- [x] Extended the README / documentation, if necessary
|
||||
- [x] Added yourself to the AUTHORS file
|
||||
|
||||
### Code Review
|
||||
|
||||
Everyone is invited to review and comment on pull requests.
|
||||
If it looks fine to you, comment with "LGTM" (Looks good to me).
|
||||
|
||||
If changes are required, notice the reviewers with "PTAL" (Please take another look) after committing the fixes.
|
||||
|
||||
Before merging the Pull Request, at least one [team member](https://github.com/orgs/surgemq/people) must have commented with "LGTM".
|
||||
@@ -1,136 +0,0 @@
|
||||
Package message is an encoder/decoder library for MQTT 3.1 and 3.1.1 messages. You can
|
||||
find the MQTT specs at the following locations:
|
||||
|
||||
> 3.1.1 - http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/
|
||||
> 3.1 - http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html
|
||||
|
||||
From the spec:
|
||||
|
||||
> MQTT is a Client Server publish/subscribe messaging transport protocol. It is
|
||||
> light weight, open, simple, and designed so as to be easy to implement. These
|
||||
> characteristics make it ideal for use in many situations, including constrained
|
||||
> environments such as for communication in Machine to Machine (M2M) and Internet
|
||||
> of Things (IoT) contexts where a small code footprint is required and/or network
|
||||
> bandwidth is at a premium.
|
||||
>
|
||||
> The MQTT protocol works by exchanging a series of MQTT messages in a defined way.
|
||||
> The protocol runs over TCP/IP, or over other network protocols that provide
|
||||
> ordered, lossless, bi-directional connections.
|
||||
|
||||
|
||||
There are two main items to take note in this package. The first is
|
||||
|
||||
```
|
||||
type MessageType byte
|
||||
```
|
||||
|
||||
MessageType is the type representing the MQTT packet types. In the MQTT spec, MQTT
|
||||
control packet type is represented as a 4-bit unsigned value. MessageType receives
|
||||
several methods that returns string representations of the names and descriptions.
|
||||
|
||||
Also, one of the methods is New(). It returns a new Message object based on the mtype
|
||||
parameter. For example:
|
||||
|
||||
```
|
||||
m, err := CONNECT.New()
|
||||
msg := m.(*ConnectMessage)
|
||||
```
|
||||
|
||||
This would return a PublishMessage struct, but mapped to the Message interface. You can
|
||||
then type assert it back to a *PublishMessage. Another way to create a new
|
||||
PublishMessage is to call
|
||||
|
||||
```
|
||||
msg := NewConnectMessage()
|
||||
```
|
||||
|
||||
Every message type has a New function that returns a new message. The list of available
|
||||
message types are defined as constants below.
|
||||
|
||||
As you may have noticed, the second important item is the Message interface. It defines
|
||||
several methods that are common to all messages, including Name(), Desc(), and Type().
|
||||
Most importantly, it also defines the Encode() and Decode() methods.
|
||||
|
||||
```
|
||||
Encode() (io.Reader, int, error)
|
||||
Decode(io.Reader) (int, error)
|
||||
```
|
||||
|
||||
Encode returns an io.Reader in which the encoded bytes can be read. The second return
|
||||
value is the number of bytes encoded, so the caller knows how many bytes there will be.
|
||||
If Encode returns an error, then the first two return values should be considered invalid.
|
||||
Any changes to the message after Encode() is called will invalidate the io.Reader.
|
||||
|
||||
Decode reads from the io.Reader parameter until a full message is decoded, or when io.Reader
|
||||
returns EOF or error. The first return value is the number of bytes read from io.Reader.
|
||||
The second is error if Decode encounters any problems.
|
||||
|
||||
With these in mind, we can now do:
|
||||
|
||||
```
|
||||
// Create a new CONNECT message
|
||||
msg := NewConnectMessage()
|
||||
|
||||
// Set the appropriate parameters
|
||||
msg.SetWillQos(1)
|
||||
msg.SetVersion(4)
|
||||
msg.SetCleanSession(true)
|
||||
msg.SetClientId([]byte("surgemq"))
|
||||
msg.SetKeepAlive(10)
|
||||
msg.SetWillTopic([]byte("will"))
|
||||
msg.SetWillMessage([]byte("send me home"))
|
||||
msg.SetUsername([]byte("surgemq"))
|
||||
msg.SetPassword([]byte("verysecret"))
|
||||
|
||||
// Encode the message and get the io.Reader
|
||||
r, n, err := msg.Encode()
|
||||
if err == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write n bytes into the connection
|
||||
m, err := io.CopyN(conn, r, int64(n))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Sent %d bytes of %s message", m, msg.Name())
|
||||
```
|
||||
|
||||
To receive a CONNECT message from a connection, we can do:
|
||||
|
||||
```
|
||||
// Create a new CONNECT message
|
||||
msg := NewConnectMessage()
|
||||
|
||||
// Decode the message by reading from conn
|
||||
n, err := msg.Decode(conn)
|
||||
```
|
||||
|
||||
If you don't know what type of message is coming down the pipe, you can do something like this:
|
||||
|
||||
```
|
||||
// Create a buffered IO reader for the connection
|
||||
br := bufio.NewReader(conn)
|
||||
|
||||
// Peek at the first byte, which contains the message type
|
||||
b, err := br.Peek(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract the type from the first byte
|
||||
t := MessageType(b[0] >> 4)
|
||||
|
||||
// Create a new message
|
||||
msg, err := t.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode it from the bufio.Reader
|
||||
n, err := msg.Decode(br)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
```
|
||||
@@ -1,168 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import "fmt"
|
||||
|
||||
// The CONNACK Packet is the packet sent by the Server in response to a CONNECT Packet
|
||||
// received from a Client. The first packet sent from the Server to the Client MUST
|
||||
// be a CONNACK Packet [MQTT-3.2.0-1].
|
||||
//
|
||||
// If the Client does not receive a CONNACK Packet from the Server within a reasonable
|
||||
// amount of time, the Client SHOULD close the Network Connection. A "reasonable" amount
|
||||
// of time depends on the type of application and the communications infrastructure.
|
||||
type ConnackMessage struct {
|
||||
header
|
||||
|
||||
sessionPresent bool
|
||||
returnCode ConnackCode
|
||||
}
|
||||
|
||||
var _ Message = (*ConnackMessage)(nil)
|
||||
|
||||
// NewConnackMessage creates a new CONNACK message
|
||||
func NewConnackMessage() *ConnackMessage {
|
||||
msg := &ConnackMessage{}
|
||||
msg.SetType(CONNACK)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// String returns a string representation of the CONNACK message
|
||||
func (this ConnackMessage) String() string {
|
||||
return fmt.Sprintf("%s, Session Present=%t, Return code=%q\n", this.header, this.sessionPresent, this.returnCode)
|
||||
}
|
||||
|
||||
// SessionPresent returns the session present flag value
|
||||
func (this *ConnackMessage) SessionPresent() bool {
|
||||
return this.sessionPresent
|
||||
}
|
||||
|
||||
// SetSessionPresent sets the value of the session present flag
|
||||
func (this *ConnackMessage) SetSessionPresent(v bool) {
|
||||
if v {
|
||||
this.sessionPresent = true
|
||||
} else {
|
||||
this.sessionPresent = false
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// ReturnCode returns the return code received for the CONNECT message. The return
|
||||
// type is an error
|
||||
func (this *ConnackMessage) ReturnCode() ConnackCode {
|
||||
return this.returnCode
|
||||
}
|
||||
|
||||
func (this *ConnackMessage) SetReturnCode(ret ConnackCode) {
|
||||
this.returnCode = ret
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
func (this *ConnackMessage) Len() int {
|
||||
if !this.dirty {
|
||||
return len(this.dbuf)
|
||||
}
|
||||
|
||||
ml := this.msglen()
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return this.header.msglen() + ml
|
||||
}
|
||||
|
||||
func (this *ConnackMessage) Decode(src []byte) (int, error) {
|
||||
total := 0
|
||||
|
||||
n, err := this.header.decode(src)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
b := src[total]
|
||||
|
||||
if b&254 != 0 {
|
||||
return 0, fmt.Errorf("connack/Decode: Bits 7-1 in Connack Acknowledge Flags byte (1) are not 0")
|
||||
}
|
||||
|
||||
this.sessionPresent = b&0x1 == 1
|
||||
total++
|
||||
|
||||
b = src[total]
|
||||
|
||||
// Read return code
|
||||
if b > 5 {
|
||||
return 0, fmt.Errorf("connack/Decode: Invalid CONNACK return code (%d)", b)
|
||||
}
|
||||
|
||||
this.returnCode = ConnackCode(b)
|
||||
total++
|
||||
|
||||
this.dirty = false
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *ConnackMessage) Encode(dst []byte) (int, error) {
|
||||
if !this.dirty {
|
||||
if len(dst) < len(this.dbuf) {
|
||||
return 0, fmt.Errorf("connack/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
||||
}
|
||||
|
||||
return copy(dst, this.dbuf), nil
|
||||
}
|
||||
|
||||
// CONNACK remaining length fixed at 2 bytes
|
||||
hl := this.header.msglen()
|
||||
ml := this.msglen()
|
||||
|
||||
if len(dst) < hl+ml {
|
||||
return 0, fmt.Errorf("connack/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
||||
}
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total := 0
|
||||
|
||||
n, err := this.header.encode(dst[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if this.sessionPresent {
|
||||
dst[total] = 1
|
||||
}
|
||||
total++
|
||||
|
||||
if this.returnCode > 5 {
|
||||
return total, fmt.Errorf("connack/Encode: Invalid CONNACK return code (%d)", this.returnCode)
|
||||
}
|
||||
|
||||
dst[total] = this.returnCode.Value()
|
||||
total++
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *ConnackMessage) msglen() int {
|
||||
return 2
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnackMessageFields(t *testing.T) {
|
||||
msg := NewConnackMessage()
|
||||
|
||||
msg.SetSessionPresent(true)
|
||||
require.True(t, msg.SessionPresent(), "Error setting session present flag.")
|
||||
|
||||
msg.SetSessionPresent(false)
|
||||
require.False(t, msg.SessionPresent(), "Error setting session present flag.")
|
||||
|
||||
msg.SetReturnCode(ConnectionAccepted)
|
||||
require.Equal(t, ConnectionAccepted, msg.ReturnCode(), "Error setting return code.")
|
||||
}
|
||||
|
||||
func TestConnackMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNACK << 4),
|
||||
2,
|
||||
0, // session not present
|
||||
0, // connection accepted
|
||||
}
|
||||
|
||||
msg := NewConnackMessage()
|
||||
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.False(t, msg.SessionPresent(), "Error decoding session present flag.")
|
||||
require.Equal(t, ConnectionAccepted, msg.ReturnCode(), "Error decoding return code.")
|
||||
}
|
||||
|
||||
// testing wrong message length
|
||||
func TestConnackMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNACK << 4),
|
||||
3,
|
||||
0, // session not present
|
||||
0, // connection accepted
|
||||
}
|
||||
|
||||
msg := NewConnackMessage()
|
||||
|
||||
_, err := msg.Decode(msgBytes)
|
||||
require.Error(t, err, "Error decoding message.")
|
||||
}
|
||||
|
||||
// testing wrong message size
|
||||
func TestConnackMessageDecode3(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNACK << 4),
|
||||
2,
|
||||
0, // session not present
|
||||
}
|
||||
|
||||
msg := NewConnackMessage()
|
||||
|
||||
_, err := msg.Decode(msgBytes)
|
||||
require.Error(t, err, "Error decoding message.")
|
||||
}
|
||||
|
||||
// testing wrong reserve bits
|
||||
func TestConnackMessageDecode4(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNACK << 4),
|
||||
2,
|
||||
64, // <- wrong size
|
||||
0, // connection accepted
|
||||
}
|
||||
|
||||
msg := NewConnackMessage()
|
||||
|
||||
_, err := msg.Decode(msgBytes)
|
||||
require.Error(t, err, "Error decoding message.")
|
||||
}
|
||||
|
||||
// testing invalid return code
|
||||
func TestConnackMessageDecode5(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNACK << 4),
|
||||
2,
|
||||
0,
|
||||
6, // <- wrong code
|
||||
}
|
||||
|
||||
msg := NewConnackMessage()
|
||||
|
||||
_, err := msg.Decode(msgBytes)
|
||||
require.Error(t, err, "Error decoding message.")
|
||||
}
|
||||
|
||||
func TestConnackMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNACK << 4),
|
||||
2,
|
||||
1, // session present
|
||||
0, // connection accepted
|
||||
}
|
||||
|
||||
msg := NewConnackMessage()
|
||||
msg.SetReturnCode(ConnectionAccepted)
|
||||
msg.SetSessionPresent(true)
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error encoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error encoding connack message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestConnackDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNACK << 4),
|
||||
2,
|
||||
0, // session not present
|
||||
0, // connection accepted
|
||||
}
|
||||
|
||||
msg := NewConnackMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
// ConnackCode is the type representing the return code in the CONNACK message,
|
||||
// returned after the initial CONNECT message
|
||||
type ConnackCode byte
|
||||
|
||||
const (
|
||||
// Connection accepted
|
||||
ConnectionAccepted ConnackCode = iota
|
||||
|
||||
// The Server does not support the level of the MQTT protocol requested by the Client
|
||||
ErrInvalidProtocolVersion
|
||||
|
||||
// The Client identifier is correct UTF-8 but not allowed by the server
|
||||
ErrIdentifierRejected
|
||||
|
||||
// The Network Connection has been made but the MQTT service is unavailable
|
||||
ErrServerUnavailable
|
||||
|
||||
// The data in the user name or password is malformed
|
||||
ErrBadUsernameOrPassword
|
||||
|
||||
// The Client is not authorized to connect
|
||||
ErrNotAuthorized
|
||||
)
|
||||
|
||||
// Value returns the value of the ConnackCode, which is just the byte representation
|
||||
func (this ConnackCode) Value() byte {
|
||||
return byte(this)
|
||||
}
|
||||
|
||||
// Desc returns the description of the ConnackCode
|
||||
func (this ConnackCode) Desc() string {
|
||||
switch this {
|
||||
case 0:
|
||||
return "Connection accepted"
|
||||
case 1:
|
||||
return "The Server does not support the level of the MQTT protocol requested by the Client"
|
||||
case 2:
|
||||
return "The Client identifier is correct UTF-8 but not allowed by the server"
|
||||
case 3:
|
||||
return "The Network Connection has been made but the MQTT service is unavailable"
|
||||
case 4:
|
||||
return "The data in the user name or password is malformed"
|
||||
case 5:
|
||||
return "The Client is not authorized to connect"
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// Valid checks to see if the ConnackCode is valid. Currently valid codes are <= 5
|
||||
func (this ConnackCode) Valid() bool {
|
||||
return this <= 5
|
||||
}
|
||||
|
||||
// Error returns the corresonding error string for the ConnackCode
|
||||
func (this ConnackCode) Error() string {
|
||||
switch this {
|
||||
case 0:
|
||||
return "Connection accepted"
|
||||
case 1:
|
||||
return "Connection Refused, unacceptable protocol version"
|
||||
case 2:
|
||||
return "Connection Refused, identifier rejected"
|
||||
case 3:
|
||||
return "Connection Refused, Server unavailable"
|
||||
case 4:
|
||||
return "Connection Refused, bad user name or password"
|
||||
case 5:
|
||||
return "Connection Refused, not authorized"
|
||||
}
|
||||
|
||||
return "Unknown error"
|
||||
}
|
||||
@@ -1,635 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
var clientIdRegexp *regexp.Regexp
|
||||
|
||||
func init() {
|
||||
// Added space for Paho compliance test
|
||||
// Added underscore (_) for MQTT C client test
|
||||
clientIdRegexp = regexp.MustCompile("^[0-9a-zA-Z _]*$")
|
||||
}
|
||||
|
||||
// After a Network Connection is established by a Client to a Server, the first Packet
|
||||
// sent from the Client to the Server MUST be a CONNECT Packet [MQTT-3.1.0-1].
|
||||
//
|
||||
// A Client can only send the CONNECT Packet once over a Network Connection. The Server
|
||||
// MUST process a second CONNECT Packet sent from a Client as a protocol violation and
|
||||
// disconnect the Client [MQTT-3.1.0-2]. See section 4.8 for information about
|
||||
// handling errors.
|
||||
type ConnectMessage struct {
|
||||
header
|
||||
|
||||
// 7: username flag
|
||||
// 6: password flag
|
||||
// 5: will retain
|
||||
// 4-3: will QoS
|
||||
// 2: will flag
|
||||
// 1: clean session
|
||||
// 0: reserved
|
||||
connectFlags byte
|
||||
|
||||
version byte
|
||||
|
||||
keepAlive uint16
|
||||
|
||||
protoName,
|
||||
clientId,
|
||||
willTopic,
|
||||
willMessage,
|
||||
username,
|
||||
password []byte
|
||||
}
|
||||
|
||||
var _ Message = (*ConnectMessage)(nil)
|
||||
|
||||
// NewConnectMessage creates a new CONNECT message.
|
||||
func NewConnectMessage() *ConnectMessage {
|
||||
msg := &ConnectMessage{}
|
||||
msg.SetType(CONNECT)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// String returns a string representation of the CONNECT message
|
||||
func (this ConnectMessage) String() string {
|
||||
return fmt.Sprintf("%s, Connect Flags=%08b, Version=%d, KeepAlive=%d, Client ID=%q, Will Topic=%q, Will Message=%q, Username=%q, Password=%q",
|
||||
this.header,
|
||||
this.connectFlags,
|
||||
this.Version(),
|
||||
this.KeepAlive(),
|
||||
this.ClientId(),
|
||||
this.WillTopic(),
|
||||
this.WillMessage(),
|
||||
this.Username(),
|
||||
this.Password(),
|
||||
)
|
||||
}
|
||||
|
||||
// Version returns the the 8 bit unsigned value that represents the revision level
|
||||
// of the protocol used by the Client. The value of the Protocol Level field for
|
||||
// the version 3.1.1 of the protocol is 4 (0x04).
|
||||
func (this *ConnectMessage) Version() byte {
|
||||
return this.version
|
||||
}
|
||||
|
||||
// SetVersion sets the version value of the CONNECT message
|
||||
func (this *ConnectMessage) SetVersion(v byte) error {
|
||||
if _, ok := SupportedVersions[v]; !ok {
|
||||
return fmt.Errorf("connect/SetVersion: Invalid version number %d", v)
|
||||
}
|
||||
|
||||
this.version = v
|
||||
this.dirty = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CleanSession returns the bit that specifies the handling of the Session state.
|
||||
// The Client and Server can store Session state to enable reliable messaging to
|
||||
// continue across a sequence of Network Connections. This bit is used to control
|
||||
// the lifetime of the Session state.
|
||||
func (this *ConnectMessage) CleanSession() bool {
|
||||
return ((this.connectFlags >> 1) & 0x1) == 1
|
||||
}
|
||||
|
||||
// SetCleanSession sets the bit that specifies the handling of the Session state.
|
||||
func (this *ConnectMessage) SetCleanSession(v bool) {
|
||||
if v {
|
||||
this.connectFlags |= 0x2 // 00000010
|
||||
} else {
|
||||
this.connectFlags &= 253 // 11111101
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// WillFlag returns the bit that specifies whether a Will Message should be stored
|
||||
// on the server. If the Will Flag is set to 1 this indicates that, if the Connect
|
||||
// request is accepted, a Will Message MUST be stored on the Server and associated
|
||||
// with the Network Connection.
|
||||
func (this *ConnectMessage) WillFlag() bool {
|
||||
return ((this.connectFlags >> 2) & 0x1) == 1
|
||||
}
|
||||
|
||||
// SetWillFlag sets the bit that specifies whether a Will Message should be stored
|
||||
// on the server.
|
||||
func (this *ConnectMessage) SetWillFlag(v bool) {
|
||||
if v {
|
||||
this.connectFlags |= 0x4 // 00000100
|
||||
} else {
|
||||
this.connectFlags &= 251 // 11111011
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// WillQos returns the two bits that specify the QoS level to be used when publishing
|
||||
// the Will Message.
|
||||
func (this *ConnectMessage) WillQos() byte {
|
||||
return (this.connectFlags >> 3) & 0x3
|
||||
}
|
||||
|
||||
// SetWillQos sets the two bits that specify the QoS level to be used when publishing
|
||||
// the Will Message.
|
||||
func (this *ConnectMessage) SetWillQos(qos byte) error {
|
||||
if qos != QosAtMostOnce && qos != QosAtLeastOnce && qos != QosExactlyOnce {
|
||||
return fmt.Errorf("connect/SetWillQos: Invalid QoS level %d", qos)
|
||||
}
|
||||
|
||||
this.connectFlags = (this.connectFlags & 231) | (qos << 3) // 231 = 11100111
|
||||
this.dirty = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WillRetain returns the bit specifies if the Will Message is to be Retained when it
|
||||
// is published.
|
||||
func (this *ConnectMessage) WillRetain() bool {
|
||||
return ((this.connectFlags >> 5) & 0x1) == 1
|
||||
}
|
||||
|
||||
// SetWillRetain sets the bit specifies if the Will Message is to be Retained when it
|
||||
// is published.
|
||||
func (this *ConnectMessage) SetWillRetain(v bool) {
|
||||
if v {
|
||||
this.connectFlags |= 32 // 00100000
|
||||
} else {
|
||||
this.connectFlags &= 223 // 11011111
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// UsernameFlag returns the bit that specifies whether a user name is present in the
|
||||
// payload.
|
||||
func (this *ConnectMessage) UsernameFlag() bool {
|
||||
return ((this.connectFlags >> 7) & 0x1) == 1
|
||||
}
|
||||
|
||||
// SetUsernameFlag sets the bit that specifies whether a user name is present in the
|
||||
// payload.
|
||||
func (this *ConnectMessage) SetUsernameFlag(v bool) {
|
||||
if v {
|
||||
this.connectFlags |= 128 // 10000000
|
||||
} else {
|
||||
this.connectFlags &= 127 // 01111111
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// PasswordFlag returns the bit that specifies whether a password is present in the
|
||||
// payload.
|
||||
func (this *ConnectMessage) PasswordFlag() bool {
|
||||
return ((this.connectFlags >> 6) & 0x1) == 1
|
||||
}
|
||||
|
||||
// SetPasswordFlag sets the bit that specifies whether a password is present in the
|
||||
// payload.
|
||||
func (this *ConnectMessage) SetPasswordFlag(v bool) {
|
||||
if v {
|
||||
this.connectFlags |= 64 // 01000000
|
||||
} else {
|
||||
this.connectFlags &= 191 // 10111111
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// KeepAlive returns a time interval measured in seconds. Expressed as a 16-bit word,
|
||||
// it is the maximum time interval that is permitted to elapse between the point at
|
||||
// which the Client finishes transmitting one Control Packet and the point it starts
|
||||
// sending the next.
|
||||
func (this *ConnectMessage) KeepAlive() uint16 {
|
||||
return this.keepAlive
|
||||
}
|
||||
|
||||
// SetKeepAlive sets the time interval in which the server should keep the connection
|
||||
// alive.
|
||||
func (this *ConnectMessage) SetKeepAlive(v uint16) {
|
||||
this.keepAlive = v
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// ClientId returns an ID that identifies the Client to the Server. Each Client
|
||||
// connecting to the Server has a unique ClientId. The ClientId MUST be used by
|
||||
// Clients and by Servers to identify state that they hold relating to this MQTT
|
||||
// Session between the Client and the Server
|
||||
func (this *ConnectMessage) ClientId() []byte {
|
||||
return this.clientId
|
||||
}
|
||||
|
||||
// SetClientId sets an ID that identifies the Client to the Server.
|
||||
func (this *ConnectMessage) SetClientId(v []byte) error {
|
||||
if len(v) > 0 && !this.validClientId(v) {
|
||||
return ErrIdentifierRejected
|
||||
}
|
||||
|
||||
this.clientId = v
|
||||
this.dirty = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// WillTopic returns the topic in which the Will Message should be published to.
|
||||
// If the Will Flag is set to 1, the Will Topic must be in the payload.
|
||||
func (this *ConnectMessage) WillTopic() []byte {
|
||||
return this.willTopic
|
||||
}
|
||||
|
||||
// SetWillTopic sets the topic in which the Will Message should be published to.
|
||||
func (this *ConnectMessage) SetWillTopic(v []byte) {
|
||||
this.willTopic = v
|
||||
|
||||
if len(v) > 0 {
|
||||
this.SetWillFlag(true)
|
||||
} else if len(this.willMessage) == 0 {
|
||||
this.SetWillFlag(false)
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// WillMessage returns the Will Message that is to be published to the Will Topic.
|
||||
func (this *ConnectMessage) WillMessage() []byte {
|
||||
return this.willMessage
|
||||
}
|
||||
|
||||
// SetWillMessage sets the Will Message that is to be published to the Will Topic.
|
||||
func (this *ConnectMessage) SetWillMessage(v []byte) {
|
||||
this.willMessage = v
|
||||
|
||||
if len(v) > 0 {
|
||||
this.SetWillFlag(true)
|
||||
} else if len(this.willTopic) == 0 {
|
||||
this.SetWillFlag(false)
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// Username returns the username from the payload. If the User Name Flag is set to 1,
|
||||
// this must be in the payload. It can be used by the Server for authentication and
|
||||
// authorization.
|
||||
func (this *ConnectMessage) Username() []byte {
|
||||
return this.username
|
||||
}
|
||||
|
||||
// SetUsername sets the username for authentication.
|
||||
func (this *ConnectMessage) SetUsername(v []byte) {
|
||||
this.username = v
|
||||
|
||||
if len(v) > 0 {
|
||||
this.SetUsernameFlag(true)
|
||||
} else {
|
||||
this.SetUsernameFlag(false)
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// Password returns the password from the payload. If the Password Flag is set to 1,
|
||||
// this must be in the payload. It can be used by the Server for authentication and
|
||||
// authorization.
|
||||
func (this *ConnectMessage) Password() []byte {
|
||||
return this.password
|
||||
}
|
||||
|
||||
// SetPassword sets the username for authentication.
|
||||
func (this *ConnectMessage) SetPassword(v []byte) {
|
||||
this.password = v
|
||||
|
||||
if len(v) > 0 {
|
||||
this.SetPasswordFlag(true)
|
||||
} else {
|
||||
this.SetPasswordFlag(false)
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
func (this *ConnectMessage) Len() int {
|
||||
if !this.dirty {
|
||||
return len(this.dbuf)
|
||||
}
|
||||
|
||||
ml := this.msglen()
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return this.header.msglen() + ml
|
||||
}
|
||||
|
||||
// For the CONNECT message, the error returned could be a ConnackReturnCode, so
|
||||
// be sure to check that. Otherwise it's a generic error. If a generic error is
|
||||
// returned, this Message should be considered invalid.
|
||||
//
|
||||
// Caller should call ValidConnackError(err) to see if the returned error is
|
||||
// a Connack error. If so, caller should send the Client back the corresponding
|
||||
// CONNACK message.
|
||||
func (this *ConnectMessage) Decode(src []byte) (int, error) {
|
||||
total := 0
|
||||
|
||||
n, err := this.header.decode(src[total:])
|
||||
if err != nil {
|
||||
return total + n, err
|
||||
}
|
||||
total += n
|
||||
|
||||
if n, err = this.decodeMessage(src[total:]); err != nil {
|
||||
return total + n, err
|
||||
}
|
||||
total += n
|
||||
|
||||
this.dirty = false
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *ConnectMessage) Encode(dst []byte) (int, error) {
|
||||
if !this.dirty {
|
||||
if len(dst) < len(this.dbuf) {
|
||||
return 0, fmt.Errorf("connect/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
||||
}
|
||||
|
||||
return copy(dst, this.dbuf), nil
|
||||
}
|
||||
|
||||
if this.Type() != CONNECT {
|
||||
return 0, fmt.Errorf("connect/Encode: Invalid message type. Expecting %d, got %d", CONNECT, this.Type())
|
||||
}
|
||||
|
||||
_, ok := SupportedVersions[this.version]
|
||||
if !ok {
|
||||
return 0, ErrInvalidProtocolVersion
|
||||
}
|
||||
|
||||
hl := this.header.msglen()
|
||||
ml := this.msglen()
|
||||
|
||||
if len(dst) < hl+ml {
|
||||
return 0, fmt.Errorf("connect/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
||||
}
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total := 0
|
||||
|
||||
n, err := this.header.encode(dst[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
n, err = this.encodeMessage(dst[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *ConnectMessage) encodeMessage(dst []byte) (int, error) {
|
||||
total := 0
|
||||
|
||||
n, err := writeLPBytes(dst[total:], []byte(SupportedVersions[this.version]))
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
dst[total] = this.version
|
||||
total += 1
|
||||
|
||||
dst[total] = this.connectFlags
|
||||
total += 1
|
||||
|
||||
binary.BigEndian.PutUint16(dst[total:], this.keepAlive)
|
||||
total += 2
|
||||
|
||||
n, err = writeLPBytes(dst[total:], this.clientId)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
if this.WillFlag() {
|
||||
n, err = writeLPBytes(dst[total:], this.willTopic)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
n, err = writeLPBytes(dst[total:], this.willMessage)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
|
||||
// According to the 3.1 spec, it's possible that the usernameFlag is set,
|
||||
// but the username string is missing.
|
||||
if this.UsernameFlag() && len(this.username) > 0 {
|
||||
n, err = writeLPBytes(dst[total:], this.username)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
|
||||
// According to the 3.1 spec, it's possible that the passwordFlag is set,
|
||||
// but the password string is missing.
|
||||
if this.PasswordFlag() && len(this.password) > 0 {
|
||||
n, err = writeLPBytes(dst[total:], this.password)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *ConnectMessage) decodeMessage(src []byte) (int, error) {
|
||||
var err error
|
||||
n, total := 0, 0
|
||||
|
||||
this.protoName, n, err = readLPBytes(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
this.version = src[total]
|
||||
total++
|
||||
|
||||
if verstr, ok := SupportedVersions[this.version]; !ok {
|
||||
return total, ErrInvalidProtocolVersion
|
||||
} else if verstr != string(this.protoName) {
|
||||
return total, ErrInvalidProtocolVersion
|
||||
}
|
||||
|
||||
this.connectFlags = src[total]
|
||||
total++
|
||||
|
||||
if this.connectFlags&0x1 != 0 {
|
||||
return total, fmt.Errorf("connect/decodeMessage: Connect Flags reserved bit 0 is not 0")
|
||||
}
|
||||
|
||||
if this.WillQos() > QosExactlyOnce {
|
||||
return total, fmt.Errorf("connect/decodeMessage: Invalid QoS level (%d) for %s message", this.WillQos(), this.Name())
|
||||
}
|
||||
|
||||
if !this.WillFlag() && (this.WillRetain() || this.WillQos() != QosAtMostOnce) {
|
||||
return total, fmt.Errorf("connect/decodeMessage: Protocol violation: If the Will Flag (%t) is set to 0 the Will QoS (%d) and Will Retain (%t) fields MUST be set to zero", this.WillFlag(), this.WillQos(), this.WillRetain())
|
||||
}
|
||||
|
||||
if this.UsernameFlag() && !this.PasswordFlag() {
|
||||
return total, fmt.Errorf("connect/decodeMessage: Username flag is set but Password flag is not set")
|
||||
}
|
||||
|
||||
if len(src[total:]) < 2 {
|
||||
return 0, fmt.Errorf("connect/decodeMessage: Insufficient buffer size. Expecting %d, got %d.", 2, len(src[total:]))
|
||||
}
|
||||
|
||||
this.keepAlive = binary.BigEndian.Uint16(src[total:])
|
||||
total += 2
|
||||
|
||||
this.clientId, n, err = readLPBytes(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
// If the Client supplies a zero-byte ClientId, the Client MUST also set CleanSession to 1
|
||||
if len(this.clientId) == 0 && !this.CleanSession() {
|
||||
return total, ErrIdentifierRejected
|
||||
}
|
||||
|
||||
// The ClientId must contain only characters 0-9, a-z, and A-Z
|
||||
// We also support ClientId longer than 23 encoded bytes
|
||||
// We do not support ClientId outside of the above characters
|
||||
if len(this.clientId) > 0 && !this.validClientId(this.clientId) {
|
||||
return total, ErrIdentifierRejected
|
||||
}
|
||||
|
||||
if this.WillFlag() {
|
||||
this.willTopic, n, err = readLPBytes(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
this.willMessage, n, err = readLPBytes(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
|
||||
// According to the 3.1 spec, it's possible that the passwordFlag is set,
|
||||
// but the password string is missing.
|
||||
if this.UsernameFlag() && len(src[total:]) > 0 {
|
||||
this.username, n, err = readLPBytes(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
|
||||
// According to the 3.1 spec, it's possible that the passwordFlag is set,
|
||||
// but the password string is missing.
|
||||
if this.PasswordFlag() && len(src[total:]) > 0 {
|
||||
this.password, n, err = readLPBytes(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *ConnectMessage) msglen() int {
|
||||
total := 0
|
||||
|
||||
verstr, ok := SupportedVersions[this.version]
|
||||
if !ok {
|
||||
return total
|
||||
}
|
||||
|
||||
// 2 bytes protocol name length
|
||||
// n bytes protocol name
|
||||
// 1 byte protocol version
|
||||
// 1 byte connect flags
|
||||
// 2 bytes keep alive timer
|
||||
total += 2 + len(verstr) + 1 + 1 + 2
|
||||
|
||||
// Add the clientID length, 2 is the length prefix
|
||||
total += 2 + len(this.clientId)
|
||||
|
||||
// Add the will topic and will message length, and the length prefixes
|
||||
if this.WillFlag() {
|
||||
total += 2 + len(this.willTopic) + 2 + len(this.willMessage)
|
||||
}
|
||||
|
||||
// Add the username length
|
||||
// According to the 3.1 spec, it's possible that the usernameFlag is set,
|
||||
// but the user name string is missing.
|
||||
if this.UsernameFlag() && len(this.username) > 0 {
|
||||
total += 2 + len(this.username)
|
||||
}
|
||||
|
||||
// Add the password length
|
||||
// According to the 3.1 spec, it's possible that the passwordFlag is set,
|
||||
// but the password string is missing.
|
||||
if this.PasswordFlag() && len(this.password) > 0 {
|
||||
total += 2 + len(this.password)
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
|
||||
// validClientId checks the client ID, which is a slice of bytes, to see if it's valid.
|
||||
// Client ID is valid if it meets the requirement from the MQTT spec:
|
||||
// The Server MUST allow ClientIds which are between 1 and 23 UTF-8 encoded bytes in length,
|
||||
// and that contain only the characters
|
||||
//
|
||||
// "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"
|
||||
func (this *ConnectMessage) validClientId(cid []byte) bool {
|
||||
// Fixed https://github.com/surgemq/surgemq/issues/4
|
||||
//if len(cid) > 23 {
|
||||
// return false
|
||||
//}
|
||||
|
||||
if this.Version() == 0x3 {
|
||||
return true
|
||||
}
|
||||
|
||||
return clientIdRegexp.Match(cid)
|
||||
}
|
||||
@@ -1,373 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestConnectMessageFields(t *testing.T) {
|
||||
msg := NewConnectMessage()
|
||||
|
||||
err := msg.SetVersion(0x3)
|
||||
require.NoError(t, err, "Error setting message version.")
|
||||
|
||||
require.Equal(t, 0x3, int(msg.Version()), "Incorrect version number")
|
||||
|
||||
err = msg.SetVersion(0x5)
|
||||
require.Error(t, err)
|
||||
|
||||
msg.SetCleanSession(true)
|
||||
require.True(t, msg.CleanSession(), "Error setting clean session flag.")
|
||||
|
||||
msg.SetCleanSession(false)
|
||||
require.False(t, msg.CleanSession(), "Error setting clean session flag.")
|
||||
|
||||
msg.SetWillFlag(true)
|
||||
require.True(t, msg.WillFlag(), "Error setting will flag.")
|
||||
|
||||
msg.SetWillFlag(false)
|
||||
require.False(t, msg.WillFlag(), "Error setting will flag.")
|
||||
|
||||
msg.SetWillRetain(true)
|
||||
require.True(t, msg.WillRetain(), "Error setting will retain.")
|
||||
|
||||
msg.SetWillRetain(false)
|
||||
require.False(t, msg.WillRetain(), "Error setting will retain.")
|
||||
|
||||
msg.SetPasswordFlag(true)
|
||||
require.True(t, msg.PasswordFlag(), "Error setting password flag.")
|
||||
|
||||
msg.SetPasswordFlag(false)
|
||||
require.False(t, msg.PasswordFlag(), "Error setting password flag.")
|
||||
|
||||
msg.SetUsernameFlag(true)
|
||||
require.True(t, msg.UsernameFlag(), "Error setting username flag.")
|
||||
|
||||
msg.SetUsernameFlag(false)
|
||||
require.False(t, msg.UsernameFlag(), "Error setting username flag.")
|
||||
|
||||
msg.SetWillQos(1)
|
||||
require.Equal(t, 1, int(msg.WillQos()), "Error setting will QoS.")
|
||||
|
||||
err = msg.SetWillQos(4)
|
||||
require.Error(t, err)
|
||||
|
||||
err = msg.SetClientId([]byte("j0j0jfajf02j0asdjf"))
|
||||
require.NoError(t, err, "Error setting client ID")
|
||||
|
||||
require.Equal(t, "j0j0jfajf02j0asdjf", string(msg.ClientId()), "Error setting client ID.")
|
||||
|
||||
err = msg.SetClientId([]byte("this is good for v3"))
|
||||
require.NoError(t, err)
|
||||
|
||||
msg.SetVersion(0x4)
|
||||
|
||||
err = msg.SetClientId([]byte("this is no good for v4!"))
|
||||
require.Error(t, err)
|
||||
|
||||
msg.SetVersion(0x3)
|
||||
|
||||
msg.SetWillTopic([]byte("willtopic"))
|
||||
require.Equal(t, "willtopic", string(msg.WillTopic()), "Error setting will topic.")
|
||||
|
||||
require.True(t, msg.WillFlag(), "Error setting will flag.")
|
||||
|
||||
msg.SetWillTopic([]byte(""))
|
||||
require.Equal(t, "", string(msg.WillTopic()), "Error setting will topic.")
|
||||
|
||||
require.False(t, msg.WillFlag(), "Error setting will flag.")
|
||||
|
||||
msg.SetWillMessage([]byte("this is a will message"))
|
||||
require.Equal(t, "this is a will message", string(msg.WillMessage()), "Error setting will message.")
|
||||
|
||||
require.True(t, msg.WillFlag(), "Error setting will flag.")
|
||||
|
||||
msg.SetWillMessage([]byte(""))
|
||||
require.Equal(t, "", string(msg.WillMessage()), "Error setting will topic.")
|
||||
|
||||
require.False(t, msg.WillFlag(), "Error setting will flag.")
|
||||
|
||||
msg.SetWillTopic([]byte("willtopic"))
|
||||
msg.SetWillMessage([]byte("this is a will message"))
|
||||
msg.SetWillTopic([]byte(""))
|
||||
require.True(t, msg.WillFlag(), "Error setting will topic.")
|
||||
|
||||
msg.SetUsername([]byte("myname"))
|
||||
require.Equal(t, "myname", string(msg.Username()), "Error setting will message.")
|
||||
|
||||
require.True(t, msg.UsernameFlag(), "Error setting will flag.")
|
||||
|
||||
msg.SetUsername([]byte(""))
|
||||
require.Equal(t, "", string(msg.Username()), "Error setting will message.")
|
||||
|
||||
require.False(t, msg.UsernameFlag(), "Error setting will flag.")
|
||||
|
||||
msg.SetPassword([]byte("myname"))
|
||||
require.Equal(t, "myname", string(msg.Password()), "Error setting will message.")
|
||||
|
||||
require.True(t, msg.PasswordFlag(), "Error setting will flag.")
|
||||
|
||||
msg.SetPassword([]byte(""))
|
||||
require.Equal(t, "", string(msg.Password()), "Error setting will message.")
|
||||
|
||||
require.False(t, msg.PasswordFlag(), "Error setting will flag.")
|
||||
}
|
||||
|
||||
func TestConnectMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNECT << 4),
|
||||
60,
|
||||
0, // Length MSB (0)
|
||||
4, // Length LSB (4)
|
||||
'M', 'Q', 'T', 'T',
|
||||
4, // Protocol level 4
|
||||
206, // connect flags 11001110, will QoS = 01
|
||||
0, // Keep Alive MSB (0)
|
||||
10, // Keep Alive LSB (10)
|
||||
0, // Client ID MSB (0)
|
||||
7, // Client ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Will Topic MSB (0)
|
||||
4, // Will Topic LSB (4)
|
||||
'w', 'i', 'l', 'l',
|
||||
0, // Will Message MSB (0)
|
||||
12, // Will Message LSB (12)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
0, // Username ID MSB (0)
|
||||
7, // Username ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Password ID MSB (0)
|
||||
10, // Password ID LSB (10)
|
||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
||||
}
|
||||
|
||||
msg := NewConnectMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, 206, int(msg.connectFlags), "Incorrect flag value.")
|
||||
require.Equal(t, 10, int(msg.KeepAlive()), "Incorrect KeepAlive value.")
|
||||
require.Equal(t, "surgemq", string(msg.ClientId()), "Incorrect client ID value.")
|
||||
require.Equal(t, "will", string(msg.WillTopic()), "Incorrect will topic value.")
|
||||
require.Equal(t, "send me home", string(msg.WillMessage()), "Incorrect will message value.")
|
||||
require.Equal(t, "surgemq", string(msg.Username()), "Incorrect username value.")
|
||||
require.Equal(t, "verysecret", string(msg.Password()), "Incorrect password value.")
|
||||
}
|
||||
|
||||
func TestConnectMessageDecode2(t *testing.T) {
|
||||
// missing last byte 't'
|
||||
msgBytes := []byte{
|
||||
byte(CONNECT << 4),
|
||||
60,
|
||||
0, // Length MSB (0)
|
||||
4, // Length LSB (4)
|
||||
'M', 'Q', 'T', 'T',
|
||||
4, // Protocol level 4
|
||||
206, // connect flags 11001110, will QoS = 01
|
||||
0, // Keep Alive MSB (0)
|
||||
10, // Keep Alive LSB (10)
|
||||
0, // Client ID MSB (0)
|
||||
7, // Client ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Will Topic MSB (0)
|
||||
4, // Will Topic LSB (4)
|
||||
'w', 'i', 'l', 'l',
|
||||
0, // Will Message MSB (0)
|
||||
12, // Will Message LSB (12)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
0, // Username ID MSB (0)
|
||||
7, // Username ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Password ID MSB (0)
|
||||
10, // Password ID LSB (10)
|
||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e',
|
||||
}
|
||||
|
||||
msg := NewConnectMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestConnectMessageDecode3(t *testing.T) {
|
||||
// extra bytes
|
||||
msgBytes := []byte{
|
||||
byte(CONNECT << 4),
|
||||
60,
|
||||
0, // Length MSB (0)
|
||||
4, // Length LSB (4)
|
||||
'M', 'Q', 'T', 'T',
|
||||
4, // Protocol level 4
|
||||
206, // connect flags 11001110, will QoS = 01
|
||||
0, // Keep Alive MSB (0)
|
||||
10, // Keep Alive LSB (10)
|
||||
0, // Client ID MSB (0)
|
||||
7, // Client ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Will Topic MSB (0)
|
||||
4, // Will Topic LSB (4)
|
||||
'w', 'i', 'l', 'l',
|
||||
0, // Will Message MSB (0)
|
||||
12, // Will Message LSB (12)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
0, // Username ID MSB (0)
|
||||
7, // Username ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Password ID MSB (0)
|
||||
10, // Password ID LSB (10)
|
||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
||||
'e', 'x', 't', 'r', 'a',
|
||||
}
|
||||
|
||||
msg := NewConnectMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 62, n)
|
||||
}
|
||||
|
||||
func TestConnectMessageDecode4(t *testing.T) {
|
||||
// missing client Id, clean session == 0
|
||||
msgBytes := []byte{
|
||||
byte(CONNECT << 4),
|
||||
53,
|
||||
0, // Length MSB (0)
|
||||
4, // Length LSB (4)
|
||||
'M', 'Q', 'T', 'T',
|
||||
4, // Protocol level 4
|
||||
204, // connect flags 11001110, will QoS = 01
|
||||
0, // Keep Alive MSB (0)
|
||||
10, // Keep Alive LSB (10)
|
||||
0, // Client ID MSB (0)
|
||||
0, // Client ID LSB (0)
|
||||
0, // Will Topic MSB (0)
|
||||
4, // Will Topic LSB (4)
|
||||
'w', 'i', 'l', 'l',
|
||||
0, // Will Message MSB (0)
|
||||
12, // Will Message LSB (12)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
0, // Username ID MSB (0)
|
||||
7, // Username ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Password ID MSB (0)
|
||||
10, // Password ID LSB (10)
|
||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
||||
}
|
||||
|
||||
msg := NewConnectMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestConnectMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNECT << 4),
|
||||
60,
|
||||
0, // Length MSB (0)
|
||||
4, // Length LSB (4)
|
||||
'M', 'Q', 'T', 'T',
|
||||
4, // Protocol level 4
|
||||
206, // connect flags 11001110, will QoS = 01
|
||||
0, // Keep Alive MSB (0)
|
||||
10, // Keep Alive LSB (10)
|
||||
0, // Client ID MSB (0)
|
||||
7, // Client ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Will Topic MSB (0)
|
||||
4, // Will Topic LSB (4)
|
||||
'w', 'i', 'l', 'l',
|
||||
0, // Will Message MSB (0)
|
||||
12, // Will Message LSB (12)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
0, // Username ID MSB (0)
|
||||
7, // Username ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Password ID MSB (0)
|
||||
10, // Password ID LSB (10)
|
||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
||||
}
|
||||
|
||||
msg := NewConnectMessage()
|
||||
msg.SetWillQos(1)
|
||||
msg.SetVersion(4)
|
||||
msg.SetCleanSession(true)
|
||||
msg.SetClientId([]byte("surgemq"))
|
||||
msg.SetKeepAlive(10)
|
||||
msg.SetWillTopic([]byte("will"))
|
||||
msg.SetWillMessage([]byte("send me home"))
|
||||
msg.SetUsername([]byte("surgemq"))
|
||||
msg.SetPassword([]byte("verysecret"))
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestConnectDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(CONNECT << 4),
|
||||
60,
|
||||
0, // Length MSB (0)
|
||||
4, // Length LSB (4)
|
||||
'M', 'Q', 'T', 'T',
|
||||
4, // Protocol level 4
|
||||
206, // connect flags 11001110, will QoS = 01
|
||||
0, // Keep Alive MSB (0)
|
||||
10, // Keep Alive LSB (10)
|
||||
0, // Client ID MSB (0)
|
||||
7, // Client ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Will Topic MSB (0)
|
||||
4, // Will Topic LSB (4)
|
||||
'w', 'i', 'l', 'l',
|
||||
0, // Will Message MSB (0)
|
||||
12, // Will Message LSB (12)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
0, // Username ID MSB (0)
|
||||
7, // Username ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Password ID MSB (0)
|
||||
10, // Password ID LSB (10)
|
||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
||||
}
|
||||
|
||||
msg := NewConnectMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import "fmt"
|
||||
|
||||
// The DISCONNECT Packet is the final Control Packet sent from the Client to the Server.
|
||||
// It indicates that the Client is disconnecting cleanly.
|
||||
type DisconnectMessage struct {
|
||||
header
|
||||
}
|
||||
|
||||
var _ Message = (*DisconnectMessage)(nil)
|
||||
|
||||
// NewDisconnectMessage creates a new DISCONNECT message.
|
||||
func NewDisconnectMessage() *DisconnectMessage {
|
||||
msg := &DisconnectMessage{}
|
||||
msg.SetType(DISCONNECT)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func (this *DisconnectMessage) Decode(src []byte) (int, error) {
|
||||
return this.header.decode(src)
|
||||
}
|
||||
|
||||
func (this *DisconnectMessage) Encode(dst []byte) (int, error) {
|
||||
if !this.dirty {
|
||||
if len(dst) < len(this.dbuf) {
|
||||
return 0, fmt.Errorf("disconnect/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
||||
}
|
||||
|
||||
return copy(dst, this.dbuf), nil
|
||||
}
|
||||
|
||||
return this.header.encode(dst)
|
||||
}
|
||||
@@ -1,78 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDisconnectMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(DISCONNECT << 4),
|
||||
0,
|
||||
}
|
||||
|
||||
msg := NewDisconnectMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, DISCONNECT, msg.Type(), "Error decoding message.")
|
||||
}
|
||||
|
||||
func TestDisconnectMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(DISCONNECT << 4),
|
||||
0,
|
||||
}
|
||||
|
||||
msg := NewDisconnectMessage()
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestDisconnectDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(DISCONNECT << 4),
|
||||
0,
|
||||
}
|
||||
|
||||
msg := NewDisconnectMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,141 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
/*
|
||||
Package message is an encoder/decoder library for MQTT 3.1 and 3.1.1 messages. You can
|
||||
find the MQTT specs at the following locations:
|
||||
|
||||
3.1.1 - http://docs.oasis-open.org/mqtt/mqtt/v3.1.1/
|
||||
3.1 - http://public.dhe.ibm.com/software/dw/webservices/ws-mqtt/mqtt-v3r1.html
|
||||
|
||||
From the spec:
|
||||
|
||||
MQTT is a Client Server publish/subscribe messaging transport protocol. It is
|
||||
light weight, open, simple, and designed so as to be easy to implement. These
|
||||
characteristics make it ideal for use in many situations, including constrained
|
||||
environments such as for communication in Machine to Machine (M2M) and Internet
|
||||
of Things (IoT) contexts where a small code footprint is required and/or network
|
||||
bandwidth is at a premium.
|
||||
|
||||
The MQTT protocol works by exchanging a series of MQTT messages in a defined way.
|
||||
The protocol runs over TCP/IP, or over other network protocols that provide
|
||||
ordered, lossless, bi-directional connections.
|
||||
|
||||
|
||||
There are two main items to take note in this package. The first is
|
||||
|
||||
type MessageType byte
|
||||
|
||||
MessageType is the type representing the MQTT packet types. In the MQTT spec, MQTT
|
||||
control packet type is represented as a 4-bit unsigned value. MessageType receives
|
||||
several methods that returns string representations of the names and descriptions.
|
||||
|
||||
Also, one of the methods is New(). It returns a new Message object based on the mtype
|
||||
parameter. For example:
|
||||
|
||||
m, err := CONNECT.New()
|
||||
msg := m.(*ConnectMessage)
|
||||
|
||||
This would return a PublishMessage struct, but mapped to the Message interface. You can
|
||||
then type assert it back to a *PublishMessage. Another way to create a new
|
||||
PublishMessage is to call
|
||||
|
||||
msg := NewConnectMessage()
|
||||
|
||||
Every message type has a New function that returns a new message. The list of available
|
||||
message types are defined as constants below.
|
||||
|
||||
As you may have noticed, the second important item is the Message interface. It defines
|
||||
several methods that are common to all messages, including Name(), Desc(), and Type().
|
||||
Most importantly, it also defines the Encode() and Decode() methods.
|
||||
|
||||
Encode() (io.Reader, int, error)
|
||||
Decode(io.Reader) (int, error)
|
||||
|
||||
Encode returns an io.Reader in which the encoded bytes can be read. The second return
|
||||
value is the number of bytes encoded, so the caller knows how many bytes there will be.
|
||||
If Encode returns an error, then the first two return values should be considered invalid.
|
||||
Any changes to the message after Encode() is called will invalidate the io.Reader.
|
||||
|
||||
Decode reads from the io.Reader parameter until a full message is decoded, or when io.Reader
|
||||
returns EOF or error. The first return value is the number of bytes read from io.Reader.
|
||||
The second is error if Decode encounters any problems.
|
||||
|
||||
With these in mind, we can now do:
|
||||
|
||||
// Create a new CONNECT message
|
||||
msg := NewConnectMessage()
|
||||
|
||||
// Set the appropriate parameters
|
||||
msg.SetWillQos(1)
|
||||
msg.SetVersion(4)
|
||||
msg.SetCleanSession(true)
|
||||
msg.SetClientId([]byte("surgemq"))
|
||||
msg.SetKeepAlive(10)
|
||||
msg.SetWillTopic([]byte("will"))
|
||||
msg.SetWillMessage([]byte("send me home"))
|
||||
msg.SetUsername([]byte("surgemq"))
|
||||
msg.SetPassword([]byte("verysecret"))
|
||||
|
||||
// Encode the message and get the io.Reader
|
||||
r, n, err := msg.Encode()
|
||||
if err == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write n bytes into the connection
|
||||
m, err := io.CopyN(conn, r, int64(n))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Printf("Sent %d bytes of %s message", m, msg.Name())
|
||||
|
||||
To receive a CONNECT message from a connection, we can do:
|
||||
|
||||
// Create a new CONNECT message
|
||||
msg := NewConnectMessage()
|
||||
|
||||
// Decode the message by reading from conn
|
||||
n, err := msg.Decode(conn)
|
||||
|
||||
If you don't know what type of message is coming down the pipe, you can do something like this:
|
||||
|
||||
// Create a buffered IO reader for the connection
|
||||
br := bufio.NewReader(conn)
|
||||
|
||||
// Peek at the first byte, which contains the message type
|
||||
b, err := br.Peek(1)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Extract the type from the first byte
|
||||
t := MessageType(b[0] >> 4)
|
||||
|
||||
// Create a new message
|
||||
msg, err := t.New()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Decode it from the bufio.Reader
|
||||
n, err := msg.Decode(br)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
*/
|
||||
package message
|
||||
@@ -1,248 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var (
|
||||
gPacketId uint64 = 0
|
||||
)
|
||||
|
||||
// Fixed header
|
||||
// - 1 byte for control packet type (bits 7-4) and flags (bits 3-0)
|
||||
// - up to 4 byte for remaining length
|
||||
type header struct {
|
||||
// Header fields
|
||||
//mtype MessageType
|
||||
//flags byte
|
||||
remlen int32
|
||||
|
||||
// mtypeflags is the first byte of the buffer, 4 bits for mtype, 4 bits for flags
|
||||
mtypeflags []byte
|
||||
|
||||
// Some messages need packet ID, 2 byte uint16
|
||||
packetId []byte
|
||||
|
||||
// Points to the decoding buffer
|
||||
dbuf []byte
|
||||
|
||||
// Whether the message has changed since last decode
|
||||
dirty bool
|
||||
}
|
||||
|
||||
// String returns a string representation of the message.
|
||||
func (this header) String() string {
|
||||
return fmt.Sprintf("Type=%q, Flags=%08b, Remaining Length=%d", this.Type().Name(), this.Flags(), this.remlen)
|
||||
}
|
||||
|
||||
// Name returns a string representation of the message type. Examples include
|
||||
// "PUBLISH", "SUBSCRIBE", and others. This is statically defined for each of
|
||||
// the message types and cannot be changed.
|
||||
func (this *header) Name() string {
|
||||
return this.Type().Name()
|
||||
}
|
||||
|
||||
// Desc returns a string description of the message type. For example, a
|
||||
// CONNECT message would return "Client request to connect to Server." These
|
||||
// descriptions are statically defined (copied from the MQTT spec) and cannot
|
||||
// be changed.
|
||||
func (this *header) Desc() string {
|
||||
return this.Type().Desc()
|
||||
}
|
||||
|
||||
// Type returns the MessageType of the Message. The retured value should be one
|
||||
// of the constants defined for MessageType.
|
||||
func (this *header) Type() MessageType {
|
||||
//return this.mtype
|
||||
if len(this.mtypeflags) != 1 {
|
||||
this.mtypeflags = make([]byte, 1)
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
return MessageType(this.mtypeflags[0] >> 4)
|
||||
}
|
||||
|
||||
// SetType sets the message type of this message. It also correctly sets the
|
||||
// default flags for the message type. It returns an error if the type is invalid.
|
||||
func (this *header) SetType(mtype MessageType) error {
|
||||
if !mtype.Valid() {
|
||||
return fmt.Errorf("header/SetType: Invalid control packet type %d", mtype)
|
||||
}
|
||||
|
||||
// Notice we don't set the message to be dirty when we are not allocating a new
|
||||
// buffer. In this case, it means the buffer is probably a sub-slice of another
|
||||
// slice. If that's the case, then during encoding we would have copied the whole
|
||||
// backing buffer anyway.
|
||||
if len(this.mtypeflags) != 1 {
|
||||
this.mtypeflags = make([]byte, 1)
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
this.mtypeflags[0] = byte(mtype)<<4 | (mtype.DefaultFlags() & 0xf)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Flags returns the fixed header flags for this message.
|
||||
func (this *header) Flags() byte {
|
||||
//return this.flags
|
||||
return this.mtypeflags[0] & 0x0f
|
||||
}
|
||||
|
||||
// RemainingLength returns the length of the non-fixed-header part of the message.
|
||||
func (this *header) RemainingLength() int32 {
|
||||
return this.remlen
|
||||
}
|
||||
|
||||
// SetRemainingLength sets the length of the non-fixed-header part of the message.
|
||||
// It returns error if the length is greater than 268435455, which is the max
|
||||
// message length as defined by the MQTT spec.
|
||||
func (this *header) SetRemainingLength(remlen int32) error {
|
||||
if remlen > maxRemainingLength || remlen < 0 {
|
||||
return fmt.Errorf("header/SetLength: Remaining length (%d) out of bound (max %d, min 0)", remlen, maxRemainingLength)
|
||||
}
|
||||
|
||||
this.remlen = remlen
|
||||
this.dirty = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (this *header) Len() int {
|
||||
return this.msglen()
|
||||
}
|
||||
|
||||
// PacketId returns the ID of the packet.
|
||||
func (this *header) PacketId() uint16 {
|
||||
if len(this.packetId) == 2 {
|
||||
return binary.BigEndian.Uint16(this.packetId)
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// SetPacketId sets the ID of the packet.
|
||||
func (this *header) SetPacketId(v uint16) {
|
||||
// If setting to 0, nothing to do, move on
|
||||
if v == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// If packetId buffer is not 2 bytes (uint16), then we allocate a new one and
|
||||
// make dirty. Then we encode the packet ID into the buffer.
|
||||
if len(this.packetId) != 2 {
|
||||
this.packetId = make([]byte, 2)
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// Notice we don't set the message to be dirty when we are not allocating a new
|
||||
// buffer. In this case, it means the buffer is probably a sub-slice of another
|
||||
// slice. If that's the case, then during encoding we would have copied the whole
|
||||
// backing buffer anyway.
|
||||
binary.BigEndian.PutUint16(this.packetId, v)
|
||||
}
|
||||
|
||||
func (this *header) encode(dst []byte) (int, error) {
|
||||
ml := this.msglen()
|
||||
|
||||
if len(dst) < ml {
|
||||
return 0, fmt.Errorf("header/Encode: Insufficient buffer size. Expecting %d, got %d.", ml, len(dst))
|
||||
}
|
||||
|
||||
total := 0
|
||||
|
||||
if this.remlen > maxRemainingLength || this.remlen < 0 {
|
||||
return total, fmt.Errorf("header/Encode: Remaining length (%d) out of bound (max %d, min 0)", this.remlen, maxRemainingLength)
|
||||
}
|
||||
|
||||
if !this.Type().Valid() {
|
||||
return total, fmt.Errorf("header/Encode: Invalid message type %d", this.Type())
|
||||
}
|
||||
|
||||
dst[total] = this.mtypeflags[0]
|
||||
total += 1
|
||||
|
||||
n := binary.PutUvarint(dst[total:], uint64(this.remlen))
|
||||
total += n
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// Decode reads from the io.Reader parameter until a full message is decoded, or
|
||||
// when io.Reader returns EOF or error. The first return value is the number of
|
||||
// bytes read from io.Reader. The second is error if Decode encounters any problems.
|
||||
func (this *header) decode(src []byte) (int, error) {
|
||||
total := 0
|
||||
|
||||
this.dbuf = src
|
||||
|
||||
mtype := this.Type()
|
||||
//mtype := MessageType(0)
|
||||
|
||||
this.mtypeflags = src[total : total+1]
|
||||
//mtype := MessageType(src[total] >> 4)
|
||||
if !this.Type().Valid() {
|
||||
return total, fmt.Errorf("header/Decode: Invalid message type %d.", mtype)
|
||||
}
|
||||
|
||||
if mtype != this.Type() {
|
||||
return total, fmt.Errorf("header/Decode: Invalid message type %d. Expecting %d.", this.Type(), mtype)
|
||||
}
|
||||
|
||||
//this.flags = src[total] & 0x0f
|
||||
if this.Type() != PUBLISH && this.Flags() != this.Type().DefaultFlags() {
|
||||
return total, fmt.Errorf("header/Decode: Invalid message (%d) flags. Expecting %d, got %d", this.Type(), this.Type().DefaultFlags(), this.Flags())
|
||||
}
|
||||
|
||||
if this.Type() == PUBLISH && !ValidQos((this.Flags()>>1)&0x3) {
|
||||
return total, fmt.Errorf("header/Decode: Invalid QoS (%d) for PUBLISH message.", (this.Flags()>>1)&0x3)
|
||||
}
|
||||
|
||||
total++
|
||||
|
||||
remlen, m := binary.Uvarint(src[total:])
|
||||
total += m
|
||||
this.remlen = int32(remlen)
|
||||
|
||||
if this.remlen > maxRemainingLength || remlen < 0 {
|
||||
return total, fmt.Errorf("header/Decode: Remaining length (%d) out of bound (max %d, min 0)", this.remlen, maxRemainingLength)
|
||||
}
|
||||
|
||||
if int(this.remlen) > len(src[total:]) {
|
||||
return total, fmt.Errorf("header/Decode: Remaining length (%d) is greater than remaining buffer (%d)", this.remlen, len(src[total:]))
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *header) msglen() int {
|
||||
// message type and flag byte
|
||||
total := 1
|
||||
|
||||
if this.remlen <= 127 {
|
||||
total += 1
|
||||
} else if this.remlen <= 16383 {
|
||||
total += 2
|
||||
} else if this.remlen <= 2097151 {
|
||||
total += 3
|
||||
} else {
|
||||
total += 4
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
@@ -1,182 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMessageHeaderFields(t *testing.T) {
|
||||
header := &header{}
|
||||
|
||||
header.SetRemainingLength(33)
|
||||
|
||||
require.Equal(t, int32(33), header.RemainingLength())
|
||||
|
||||
err := header.SetRemainingLength(268435456)
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
err = header.SetRemainingLength(-1)
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
err = header.SetType(RESERVED)
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
err = header.SetType(PUBREL)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, PUBREL, header.Type())
|
||||
require.Equal(t, "PUBREL", header.Name())
|
||||
require.Equal(t, 2, int(header.Flags()))
|
||||
}
|
||||
|
||||
// Not enough bytes
|
||||
func TestMessageHeaderDecode(t *testing.T) {
|
||||
buf := []byte{0x6f, 193, 2}
|
||||
header := &header{}
|
||||
|
||||
_, err := header.decode(buf)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// Remaining length too big
|
||||
func TestMessageHeaderDecode2(t *testing.T) {
|
||||
buf := []byte{0x62, 0xff, 0xff, 0xff, 0xff}
|
||||
header := &header{}
|
||||
|
||||
_, err := header.decode(buf)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMessageHeaderDecode3(t *testing.T) {
|
||||
buf := []byte{0x62, 0xff}
|
||||
header := &header{}
|
||||
|
||||
_, err := header.decode(buf)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMessageHeaderDecode4(t *testing.T) {
|
||||
buf := []byte{0x62, 0xff, 0xff, 0xff, 0x7f}
|
||||
header := &header{
|
||||
mtypeflags: []byte{6<<4 | 2},
|
||||
//mtype: 6,
|
||||
//flags: 2,
|
||||
}
|
||||
|
||||
n, err := header.decode(buf)
|
||||
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, maxRemainingLength, header.RemainingLength())
|
||||
}
|
||||
|
||||
func TestMessageHeaderDecode5(t *testing.T) {
|
||||
buf := []byte{0x62, 0xff, 0x7f}
|
||||
header := &header{
|
||||
mtypeflags: []byte{6<<4 | 2},
|
||||
//mtype: 6,
|
||||
//flags: 2,
|
||||
}
|
||||
|
||||
n, err := header.decode(buf)
|
||||
require.Error(t, err)
|
||||
require.Equal(t, 3, n)
|
||||
}
|
||||
|
||||
func TestMessageHeaderEncode1(t *testing.T) {
|
||||
header := &header{}
|
||||
headerBytes := []byte{0x62, 193, 2}
|
||||
|
||||
err := header.SetType(PUBREL)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
err = header.SetRemainingLength(321)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 3)
|
||||
n, err := header.encode(buf)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 3, n)
|
||||
require.Equal(t, headerBytes, buf)
|
||||
}
|
||||
|
||||
func TestMessageHeaderEncode2(t *testing.T) {
|
||||
header := &header{}
|
||||
|
||||
err := header.SetType(PUBREL)
|
||||
require.NoError(t, err)
|
||||
|
||||
header.remlen = 268435456
|
||||
|
||||
buf := make([]byte, 5)
|
||||
_, err = header.encode(buf)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestMessageHeaderEncode3(t *testing.T) {
|
||||
header := &header{}
|
||||
headerBytes := []byte{0x62, 0xff, 0xff, 0xff, 0x7f}
|
||||
|
||||
err := header.SetType(PUBREL)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
err = header.SetRemainingLength(maxRemainingLength)
|
||||
|
||||
require.NoError(t, err)
|
||||
|
||||
buf := make([]byte, 5)
|
||||
n, err := header.encode(buf)
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 5, n)
|
||||
require.Equal(t, headerBytes, buf)
|
||||
}
|
||||
|
||||
func TestMessageHeaderEncode4(t *testing.T) {
|
||||
header := &header{
|
||||
mtypeflags: []byte{byte(RESERVED2) << 4},
|
||||
//mtype: 6,
|
||||
//flags: 2,
|
||||
}
|
||||
|
||||
buf := make([]byte, 5)
|
||||
_, err := header.encode(buf)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
/*
|
||||
// This test is to ensure that an empty message is at least 2 bytes long
|
||||
func TestMessageHeaderEncode5(t *testing.T) {
|
||||
msg := NewPingreqMessage()
|
||||
|
||||
dst, n, err := msg.encode()
|
||||
if err != nil {
|
||||
t.Errorf("Error encoding PINGREQ message: %v", err)
|
||||
} else if n != 2 {
|
||||
t.Errorf("Incorrect result. Expecting length of 2 bytes, got %d.", dst.(*bytes.Buffer).Len())
|
||||
}
|
||||
}
|
||||
*/
|
||||
@@ -1,410 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
const (
|
||||
maxLPString uint16 = 65535
|
||||
maxFixedHeaderLength int = 5
|
||||
maxRemainingLength int32 = 268435455 // bytes, or 256 MB
|
||||
)
|
||||
|
||||
const (
|
||||
// QoS 0: At most once delivery
|
||||
// The message is delivered according to the capabilities of the underlying network.
|
||||
// No response is sent by the receiver and no retry is performed by the sender. The
|
||||
// message arrives at the receiver either once or not at all.
|
||||
QosAtMostOnce byte = iota
|
||||
|
||||
// QoS 1: At least once delivery
|
||||
// This quality of service ensures that the message arrives at the receiver at least once.
|
||||
// A QoS 1 PUBLISH Packet has a Packet Identifier in its variable header and is acknowledged
|
||||
// by a PUBACK Packet. Section 2.3.1 provides more information about Packet Identifiers.
|
||||
QosAtLeastOnce
|
||||
|
||||
// QoS 2: Exactly once delivery
|
||||
// This is the highest quality of service, for use when neither loss nor duplication of
|
||||
// messages are acceptable. There is an increased overhead associated with this quality of
|
||||
// service.
|
||||
QosExactlyOnce
|
||||
|
||||
// QosFailure is a return value for a subscription if there's a problem while subscribing
|
||||
// to a specific topic.
|
||||
QosFailure = 0x80
|
||||
)
|
||||
|
||||
// SupportedVersions is a map of the version number (0x3 or 0x4) to the version string,
|
||||
// "MQIsdp" for 0x3, and "MQTT" for 0x4.
|
||||
var SupportedVersions map[byte]string = map[byte]string{
|
||||
0x3: "MQIsdp",
|
||||
0x4: "MQTT",
|
||||
}
|
||||
|
||||
// MessageType is the type representing the MQTT packet types. In the MQTT spec,
|
||||
// MQTT control packet type is represented as a 4-bit unsigned value.
|
||||
type MessageType byte
|
||||
|
||||
// Message is an interface defined for all MQTT message types.
|
||||
type Message interface {
|
||||
// Name returns a string representation of the message type. Examples include
|
||||
// "PUBLISH", "SUBSCRIBE", and others. This is statically defined for each of
|
||||
// the message types and cannot be changed.
|
||||
Name() string
|
||||
|
||||
// Desc returns a string description of the message type. For example, a
|
||||
// CONNECT message would return "Client request to connect to Server." These
|
||||
// descriptions are statically defined (copied from the MQTT spec) and cannot
|
||||
// be changed.
|
||||
Desc() string
|
||||
|
||||
// Type returns the MessageType of the Message. The retured value should be one
|
||||
// of the constants defined for MessageType.
|
||||
Type() MessageType
|
||||
|
||||
// PacketId returns the packet ID of the Message. The retured value is 0 if
|
||||
// there's no packet ID for this message type. Otherwise non-0.
|
||||
PacketId() uint16
|
||||
|
||||
// Encode writes the message bytes into the byte array from the argument. It
|
||||
// returns the number of bytes encoded and whether there's any errors along
|
||||
// the way. If there's any errors, then the byte slice and count should be
|
||||
// considered invalid.
|
||||
Encode([]byte) (int, error)
|
||||
|
||||
// Decode reads the bytes in the byte slice from the argument. It returns the
|
||||
// total number of bytes decoded, and whether there's any errors during the
|
||||
// process. The byte slice MUST NOT be modified during the duration of this
|
||||
// message being available since the byte slice is internally stored for
|
||||
// references.
|
||||
Decode([]byte) (int, error)
|
||||
|
||||
Len() int
|
||||
}
|
||||
|
||||
const (
|
||||
// RESERVED is a reserved value and should be considered an invalid message type
|
||||
RESERVED MessageType = iota
|
||||
|
||||
// CONNECT: Client to Server. Client request to connect to Server.
|
||||
CONNECT
|
||||
|
||||
// CONNACK: Server to Client. Connect acknowledgement.
|
||||
CONNACK
|
||||
|
||||
// PUBLISH: Client to Server, or Server to Client. Publish message.
|
||||
PUBLISH
|
||||
|
||||
// PUBACK: Client to Server, or Server to Client. Publish acknowledgment for
|
||||
// QoS 1 messages.
|
||||
PUBACK
|
||||
|
||||
// PUBACK: Client to Server, or Server to Client. Publish received for QoS 2 messages.
|
||||
// Assured delivery part 1.
|
||||
PUBREC
|
||||
|
||||
// PUBREL: Client to Server, or Server to Client. Publish release for QoS 2 messages.
|
||||
// Assured delivery part 1.
|
||||
PUBREL
|
||||
|
||||
// PUBCOMP: Client to Server, or Server to Client. Publish complete for QoS 2 messages.
|
||||
// Assured delivery part 3.
|
||||
PUBCOMP
|
||||
|
||||
// SUBSCRIBE: Client to Server. Client subscribe request.
|
||||
SUBSCRIBE
|
||||
|
||||
// SUBACK: Server to Client. Subscribe acknowledgement.
|
||||
SUBACK
|
||||
|
||||
// UNSUBSCRIBE: Client to Server. Unsubscribe request.
|
||||
UNSUBSCRIBE
|
||||
|
||||
// UNSUBACK: Server to Client. Unsubscribe acknowlegment.
|
||||
UNSUBACK
|
||||
|
||||
// PINGREQ: Client to Server. PING request.
|
||||
PINGREQ
|
||||
|
||||
// PINGRESP: Server to Client. PING response.
|
||||
PINGRESP
|
||||
|
||||
// DISCONNECT: Client to Server. Client is disconnecting.
|
||||
DISCONNECT
|
||||
|
||||
// RESERVED2 is a reserved value and should be considered an invalid message type.
|
||||
RESERVED2
|
||||
)
|
||||
|
||||
func (this MessageType) String() string {
|
||||
return this.Name()
|
||||
}
|
||||
|
||||
// Name returns the name of the message type. It should correspond to one of the
|
||||
// constant values defined for MessageType. It is statically defined and cannot
|
||||
// be changed.
|
||||
func (this MessageType) Name() string {
|
||||
switch this {
|
||||
case RESERVED:
|
||||
return "RESERVED"
|
||||
case CONNECT:
|
||||
return "CONNECT"
|
||||
case CONNACK:
|
||||
return "CONNACK"
|
||||
case PUBLISH:
|
||||
return "PUBLISH"
|
||||
case PUBACK:
|
||||
return "PUBACK"
|
||||
case PUBREC:
|
||||
return "PUBREC"
|
||||
case PUBREL:
|
||||
return "PUBREL"
|
||||
case PUBCOMP:
|
||||
return "PUBCOMP"
|
||||
case SUBSCRIBE:
|
||||
return "SUBSCRIBE"
|
||||
case SUBACK:
|
||||
return "SUBACK"
|
||||
case UNSUBSCRIBE:
|
||||
return "UNSUBSCRIBE"
|
||||
case UNSUBACK:
|
||||
return "UNSUBACK"
|
||||
case PINGREQ:
|
||||
return "PINGREQ"
|
||||
case PINGRESP:
|
||||
return "PINGRESP"
|
||||
case DISCONNECT:
|
||||
return "DISCONNECT"
|
||||
case RESERVED2:
|
||||
return "RESERVED2"
|
||||
}
|
||||
|
||||
return "UNKNOWN"
|
||||
}
|
||||
|
||||
// Desc returns the description of the message type. It is statically defined (copied
|
||||
// from MQTT spec) and cannot be changed.
|
||||
func (this MessageType) Desc() string {
|
||||
switch this {
|
||||
case RESERVED:
|
||||
return "Reserved"
|
||||
case CONNECT:
|
||||
return "Client request to connect to Server"
|
||||
case CONNACK:
|
||||
return "Connect acknowledgement"
|
||||
case PUBLISH:
|
||||
return "Publish message"
|
||||
case PUBACK:
|
||||
return "Publish acknowledgement"
|
||||
case PUBREC:
|
||||
return "Publish received (assured delivery part 1)"
|
||||
case PUBREL:
|
||||
return "Publish release (assured delivery part 2)"
|
||||
case PUBCOMP:
|
||||
return "Publish complete (assured delivery part 3)"
|
||||
case SUBSCRIBE:
|
||||
return "Client subscribe request"
|
||||
case SUBACK:
|
||||
return "Subscribe acknowledgement"
|
||||
case UNSUBSCRIBE:
|
||||
return "Unsubscribe request"
|
||||
case UNSUBACK:
|
||||
return "Unsubscribe acknowledgement"
|
||||
case PINGREQ:
|
||||
return "PING request"
|
||||
case PINGRESP:
|
||||
return "PING response"
|
||||
case DISCONNECT:
|
||||
return "Client is disconnecting"
|
||||
case RESERVED2:
|
||||
return "Reserved"
|
||||
}
|
||||
|
||||
return "UNKNOWN"
|
||||
}
|
||||
|
||||
// DefaultFlags returns the default flag values for the message type, as defined by
|
||||
// the MQTT spec.
|
||||
func (this MessageType) DefaultFlags() byte {
|
||||
switch this {
|
||||
case RESERVED:
|
||||
return 0
|
||||
case CONNECT:
|
||||
return 0
|
||||
case CONNACK:
|
||||
return 0
|
||||
case PUBLISH:
|
||||
return 0
|
||||
case PUBACK:
|
||||
return 0
|
||||
case PUBREC:
|
||||
return 0
|
||||
case PUBREL:
|
||||
return 2
|
||||
case PUBCOMP:
|
||||
return 0
|
||||
case SUBSCRIBE:
|
||||
return 2
|
||||
case SUBACK:
|
||||
return 0
|
||||
case UNSUBSCRIBE:
|
||||
return 2
|
||||
case UNSUBACK:
|
||||
return 0
|
||||
case PINGREQ:
|
||||
return 0
|
||||
case PINGRESP:
|
||||
return 0
|
||||
case DISCONNECT:
|
||||
return 0
|
||||
case RESERVED2:
|
||||
return 0
|
||||
}
|
||||
|
||||
return 0
|
||||
}
|
||||
|
||||
// New creates a new message based on the message type. It is a shortcut to call
|
||||
// one of the New*Message functions. If an error is returned then the message type
|
||||
// is invalid.
|
||||
func (this MessageType) New() (Message, error) {
|
||||
switch this {
|
||||
case CONNECT:
|
||||
return NewConnectMessage(), nil
|
||||
case CONNACK:
|
||||
return NewConnackMessage(), nil
|
||||
case PUBLISH:
|
||||
return NewPublishMessage(), nil
|
||||
case PUBACK:
|
||||
return NewPubackMessage(), nil
|
||||
case PUBREC:
|
||||
return NewPubrecMessage(), nil
|
||||
case PUBREL:
|
||||
return NewPubrelMessage(), nil
|
||||
case PUBCOMP:
|
||||
return NewPubcompMessage(), nil
|
||||
case SUBSCRIBE:
|
||||
return NewSubscribeMessage(), nil
|
||||
case SUBACK:
|
||||
return NewSubackMessage(), nil
|
||||
case UNSUBSCRIBE:
|
||||
return NewUnsubscribeMessage(), nil
|
||||
case UNSUBACK:
|
||||
return NewUnsubackMessage(), nil
|
||||
case PINGREQ:
|
||||
return NewPingreqMessage(), nil
|
||||
case PINGRESP:
|
||||
return NewPingrespMessage(), nil
|
||||
case DISCONNECT:
|
||||
return NewDisconnectMessage(), nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("msgtype/NewMessage: Invalid message type %d", this)
|
||||
}
|
||||
|
||||
// Valid returns a boolean indicating whether the message type is valid or not.
|
||||
func (this MessageType) Valid() bool {
|
||||
return this > RESERVED && this < RESERVED2
|
||||
}
|
||||
|
||||
// ValidTopic checks the topic, which is a slice of bytes, to see if it's valid. Topic is
|
||||
// considered valid if it's longer than 0 bytes, and doesn't contain any wildcard characters
|
||||
// such as + and #.
|
||||
func ValidTopic(topic []byte) bool {
|
||||
return len(topic) > 0 && bytes.IndexByte(topic, '#') == -1 && bytes.IndexByte(topic, '+') == -1
|
||||
}
|
||||
|
||||
// ValidQos checks the QoS value to see if it's valid. Valid QoS are QosAtMostOnce,
|
||||
// QosAtLeastonce, and QosExactlyOnce.
|
||||
func ValidQos(qos byte) bool {
|
||||
return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce
|
||||
}
|
||||
|
||||
// ValidVersion checks to see if the version is valid. Current supported versions include 0x3 and 0x4.
|
||||
func ValidVersion(v byte) bool {
|
||||
_, ok := SupportedVersions[v]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ValidConnackError checks to see if the error is a Connack Error or not
|
||||
func ValidConnackError(err error) bool {
|
||||
return err == ErrInvalidProtocolVersion || err == ErrIdentifierRejected ||
|
||||
err == ErrServerUnavailable || err == ErrBadUsernameOrPassword || err == ErrNotAuthorized
|
||||
}
|
||||
|
||||
func ValidConnackErrorEx(err error) (bool, ConnackCode) {
|
||||
if err == ErrInvalidProtocolVersion {
|
||||
return true, ErrInvalidProtocolVersion
|
||||
}
|
||||
if err == ErrIdentifierRejected {
|
||||
return true, ErrIdentifierRejected
|
||||
}
|
||||
if err == ErrServerUnavailable {
|
||||
return true, ErrServerUnavailable
|
||||
}
|
||||
if err == ErrBadUsernameOrPassword {
|
||||
return true, ErrBadUsernameOrPassword
|
||||
}
|
||||
if err == ErrNotAuthorized {
|
||||
return true, ErrNotAuthorized
|
||||
}
|
||||
return false, ConnectionAccepted
|
||||
|
||||
}
|
||||
|
||||
// Read length prefixed bytes
|
||||
func readLPBytes(buf []byte) ([]byte, int, error) {
|
||||
if len(buf) < 2 {
|
||||
return nil, 0, fmt.Errorf("utils/readLPBytes: Insufficient buffer size. Expecting %d, got %d.", 2, len(buf))
|
||||
}
|
||||
|
||||
n, total := 0, 0
|
||||
|
||||
n = int(binary.BigEndian.Uint16(buf))
|
||||
total += 2
|
||||
|
||||
if len(buf) < n {
|
||||
return nil, total, fmt.Errorf("utils/readLPBytes: Insufficient buffer size. Expecting %d, got %d.", n, len(buf))
|
||||
}
|
||||
|
||||
total += n
|
||||
|
||||
return buf[2:total], total, nil
|
||||
}
|
||||
|
||||
// Write length prefixed bytes
|
||||
func writeLPBytes(buf []byte, b []byte) (int, error) {
|
||||
total, n := 0, len(b)
|
||||
|
||||
if n > int(maxLPString) {
|
||||
return 0, fmt.Errorf("utils/writeLPBytes: Length (%d) greater than %d bytes.", n, maxLPString)
|
||||
}
|
||||
|
||||
if len(buf) < 2+n {
|
||||
return 0, fmt.Errorf("utils/writeLPBytes: Insufficient buffer size. Expecting %d, got %d.", 2+n, len(buf))
|
||||
}
|
||||
|
||||
binary.BigEndian.PutUint16(buf, uint16(n))
|
||||
total += 2
|
||||
|
||||
copy(buf[total:], b)
|
||||
total += n
|
||||
|
||||
return total, nil
|
||||
}
|
||||
@@ -1,178 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var (
|
||||
lpstrings []string = []string{
|
||||
"this is a test",
|
||||
"hope it succeeds",
|
||||
"but just in case",
|
||||
"send me your millions",
|
||||
"",
|
||||
}
|
||||
|
||||
lpstringBytes []byte = []byte{
|
||||
0x0, 0xe, 't', 'h', 'i', 's', ' ', 'i', 's', ' ', 'a', ' ', 't', 'e', 's', 't',
|
||||
0x0, 0x10, 'h', 'o', 'p', 'e', ' ', 'i', 't', ' ', 's', 'u', 'c', 'c', 'e', 'e', 'd', 's',
|
||||
0x0, 0x10, 'b', 'u', 't', ' ', 'j', 'u', 's', 't', ' ', 'i', 'n', ' ', 'c', 'a', 's', 'e',
|
||||
0x0, 0x15, 's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'y', 'o', 'u', 'r', ' ', 'm', 'i', 'l', 'l', 'i', 'o', 'n', 's',
|
||||
0x0, 0x0,
|
||||
}
|
||||
|
||||
msgBytes []byte = []byte{
|
||||
byte(CONNECT << 4),
|
||||
60,
|
||||
0, // Length MSB (0)
|
||||
4, // Length LSB (4)
|
||||
'M', 'Q', 'T', 'T',
|
||||
4, // Protocol level 4
|
||||
206, // connect flags 11001110, will QoS = 01
|
||||
0, // Keep Alive MSB (0)
|
||||
10, // Keep Alive LSB (10)
|
||||
0, // Client ID MSB (0)
|
||||
7, // Client ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Will Topic MSB (0)
|
||||
4, // Will Topic LSB (4)
|
||||
'w', 'i', 'l', 'l',
|
||||
0, // Will Message MSB (0)
|
||||
12, // Will Message LSB (12)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
0, // Username ID MSB (0)
|
||||
7, // Username ID LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // Password ID MSB (0)
|
||||
10, // Password ID LSB (10)
|
||||
'v', 'e', 'r', 'y', 's', 'e', 'c', 'r', 'e', 't',
|
||||
}
|
||||
)
|
||||
|
||||
func TestReadLPBytes(t *testing.T) {
|
||||
total := 0
|
||||
|
||||
for _, str := range lpstrings {
|
||||
b, n, err := readLPBytes(lpstringBytes[total:])
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, str, string(b))
|
||||
require.Equal(t, len(str)+2, n)
|
||||
|
||||
total += n
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriteLPBytes(t *testing.T) {
|
||||
total := 0
|
||||
buf := make([]byte, 1000)
|
||||
|
||||
for _, str := range lpstrings {
|
||||
n, err := writeLPBytes(buf[total:], []byte(str))
|
||||
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 2+len(str), n)
|
||||
|
||||
total += n
|
||||
}
|
||||
|
||||
require.Equal(t, lpstringBytes, buf[:total])
|
||||
}
|
||||
|
||||
func TestMessageTypes(t *testing.T) {
|
||||
if CONNECT != 1 ||
|
||||
CONNACK != 2 ||
|
||||
PUBLISH != 3 ||
|
||||
PUBACK != 4 ||
|
||||
PUBREC != 5 ||
|
||||
PUBREL != 6 ||
|
||||
PUBCOMP != 7 ||
|
||||
SUBSCRIBE != 8 ||
|
||||
SUBACK != 9 ||
|
||||
UNSUBSCRIBE != 10 ||
|
||||
UNSUBACK != 11 ||
|
||||
PINGREQ != 12 ||
|
||||
PINGRESP != 13 ||
|
||||
DISCONNECT != 14 {
|
||||
|
||||
t.Errorf("Message types have invalid code")
|
||||
}
|
||||
}
|
||||
|
||||
func TestQosCodes(t *testing.T) {
|
||||
if QosAtMostOnce != 0 || QosAtLeastOnce != 1 || QosExactlyOnce != 2 {
|
||||
t.Errorf("QOS codes invalid")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnackReturnCodes(t *testing.T) {
|
||||
require.Equal(t, ErrInvalidProtocolVersion.Error(), ConnackCode(1).Error(), "Incorrect ConnackCode error value.")
|
||||
|
||||
require.Equal(t, ErrIdentifierRejected.Error(), ConnackCode(2).Error(), "Incorrect ConnackCode error value.")
|
||||
|
||||
require.Equal(t, ErrServerUnavailable.Error(), ConnackCode(3).Error(), "Incorrect ConnackCode error value.")
|
||||
|
||||
require.Equal(t, ErrBadUsernameOrPassword.Error(), ConnackCode(4).Error(), "Incorrect ConnackCode error value.")
|
||||
|
||||
require.Equal(t, ErrNotAuthorized.Error(), ConnackCode(5).Error(), "Incorrect ConnackCode error value.")
|
||||
}
|
||||
|
||||
func TestFixedHeaderFlags(t *testing.T) {
|
||||
type detail struct {
|
||||
name string
|
||||
flags byte
|
||||
}
|
||||
|
||||
details := map[MessageType]detail{
|
||||
RESERVED: detail{"RESERVED", 0},
|
||||
CONNECT: detail{"CONNECT", 0},
|
||||
CONNACK: detail{"CONNACK", 0},
|
||||
PUBLISH: detail{"PUBLISH", 0},
|
||||
PUBACK: detail{"PUBACK", 0},
|
||||
PUBREC: detail{"PUBREC", 0},
|
||||
PUBREL: detail{"PUBREL", 2},
|
||||
PUBCOMP: detail{"PUBCOMP", 0},
|
||||
SUBSCRIBE: detail{"SUBSCRIBE", 2},
|
||||
SUBACK: detail{"SUBACK", 0},
|
||||
UNSUBSCRIBE: detail{"UNSUBSCRIBE", 2},
|
||||
UNSUBACK: detail{"UNSUBACK", 0},
|
||||
PINGREQ: detail{"PINGREQ", 0},
|
||||
PINGRESP: detail{"PINGRESP", 0},
|
||||
DISCONNECT: detail{"DISCONNECT", 0},
|
||||
RESERVED2: detail{"RESERVED2", 0},
|
||||
}
|
||||
|
||||
for m, d := range details {
|
||||
if m.Name() != d.name {
|
||||
t.Errorf("Name mismatch. Expecting %s, got %s.", d.name, m.Name())
|
||||
}
|
||||
|
||||
if m.DefaultFlags() != d.flags {
|
||||
t.Errorf("Flag mismatch for %s. Expecting %d, got %d.", m.Name(), d.flags, m.DefaultFlags())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSupportedVersions(t *testing.T) {
|
||||
for k, v := range SupportedVersions {
|
||||
if k == 0x03 && v != "MQIsdp" {
|
||||
t.Errorf("Protocol version and name mismatch. Expect %s, got %s.", "MQIsdp", v)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,135 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPingreqMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PINGREQ << 4),
|
||||
0,
|
||||
}
|
||||
|
||||
msg := NewPingreqMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, PINGREQ, msg.Type(), "Error decoding message.")
|
||||
}
|
||||
|
||||
func TestPingreqMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PINGREQ << 4),
|
||||
0,
|
||||
}
|
||||
|
||||
msg := NewPingreqMessage()
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
func TestPingrespMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PINGRESP << 4),
|
||||
0,
|
||||
}
|
||||
|
||||
msg := NewPingrespMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, PINGRESP, msg.Type(), "Error decoding message.")
|
||||
}
|
||||
|
||||
func TestPingrespMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PINGRESP << 4),
|
||||
0,
|
||||
}
|
||||
|
||||
msg := NewPingrespMessage()
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestPingreqDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PINGREQ << 4),
|
||||
0,
|
||||
}
|
||||
|
||||
msg := NewPingreqMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestPingrespDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PINGRESP << 4),
|
||||
0,
|
||||
}
|
||||
|
||||
msg := NewPingrespMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,34 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
// The PINGREQ Packet is sent from a Client to the Server. It can be used to:
|
||||
// 1. Indicate to the Server that the Client is alive in the absence of any other
|
||||
// Control Packets being sent from the Client to the Server.
|
||||
// 2. Request that the Server responds to confirm that it is alive.
|
||||
// 3. Exercise the network to indicate that the Network Connection is active.
|
||||
type PingreqMessage struct {
|
||||
DisconnectMessage
|
||||
}
|
||||
|
||||
var _ Message = (*PingreqMessage)(nil)
|
||||
|
||||
// NewPingreqMessage creates a new PINGREQ message.
|
||||
func NewPingreqMessage() *PingreqMessage {
|
||||
msg := &PingreqMessage{}
|
||||
msg.SetType(PINGREQ)
|
||||
|
||||
return msg
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
// A PINGRESP Packet is sent by the Server to the Client in response to a PINGREQ
|
||||
// Packet. It indicates that the Server is alive.
|
||||
type PingrespMessage struct {
|
||||
DisconnectMessage
|
||||
}
|
||||
|
||||
var _ Message = (*PingrespMessage)(nil)
|
||||
|
||||
// NewPingrespMessage creates a new PINGRESP message.
|
||||
func NewPingrespMessage() *PingrespMessage {
|
||||
msg := &PingrespMessage{}
|
||||
msg.SetType(PINGRESP)
|
||||
|
||||
return msg
|
||||
}
|
||||
@@ -1,109 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import "fmt"
|
||||
|
||||
// A PUBACK Packet is the response to a PUBLISH Packet with QoS level 1.
|
||||
type PubackMessage struct {
|
||||
header
|
||||
}
|
||||
|
||||
var _ Message = (*PubackMessage)(nil)
|
||||
|
||||
// NewPubackMessage creates a new PUBACK message.
|
||||
func NewPubackMessage() *PubackMessage {
|
||||
msg := &PubackMessage{}
|
||||
msg.SetType(PUBACK)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func (this PubackMessage) String() string {
|
||||
return fmt.Sprintf("%s, Packet ID=%d", this.header, this.packetId)
|
||||
}
|
||||
|
||||
func (this *PubackMessage) Len() int {
|
||||
if !this.dirty {
|
||||
return len(this.dbuf)
|
||||
}
|
||||
|
||||
ml := this.msglen()
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return this.header.msglen() + ml
|
||||
}
|
||||
|
||||
func (this *PubackMessage) Decode(src []byte) (int, error) {
|
||||
total := 0
|
||||
|
||||
n, err := this.header.decode(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
||||
this.packetId = src[total : total+2]
|
||||
total += 2
|
||||
|
||||
this.dirty = false
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *PubackMessage) Encode(dst []byte) (int, error) {
|
||||
if !this.dirty {
|
||||
if len(dst) < len(this.dbuf) {
|
||||
return 0, fmt.Errorf("puback/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
||||
}
|
||||
|
||||
return copy(dst, this.dbuf), nil
|
||||
}
|
||||
|
||||
hl := this.header.msglen()
|
||||
ml := this.msglen()
|
||||
|
||||
if len(dst) < hl+ml {
|
||||
return 0, fmt.Errorf("puback/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
||||
}
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total := 0
|
||||
|
||||
n, err := this.header.encode(dst[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
if copy(dst[total:total+2], this.packetId) != 2 {
|
||||
dst[total], dst[total+1] = 0, 0
|
||||
}
|
||||
total += 2
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *PubackMessage) msglen() int {
|
||||
// packet ID
|
||||
return 2
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPubackMessageFields(t *testing.T) {
|
||||
msg := NewPubackMessage()
|
||||
|
||||
msg.SetPacketId(100)
|
||||
|
||||
require.Equal(t, 100, int(msg.PacketId()))
|
||||
}
|
||||
|
||||
func TestPubackMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBACK << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubackMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, PUBACK, msg.Type(), "Error decoding message.")
|
||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
||||
}
|
||||
|
||||
// test insufficient bytes
|
||||
func TestPubackMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBACK << 4),
|
||||
2,
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubackMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPubackMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBACK << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubackMessage()
|
||||
msg.SetPacketId(7)
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestPubackDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBACK << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubackMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
// The PUBCOMP Packet is the response to a PUBREL Packet. It is the fourth and
|
||||
// final packet of the QoS 2 protocol exchange.
|
||||
type PubcompMessage struct {
|
||||
PubackMessage
|
||||
}
|
||||
|
||||
var _ Message = (*PubcompMessage)(nil)
|
||||
|
||||
// NewPubcompMessage creates a new PUBCOMP message.
|
||||
func NewPubcompMessage() *PubcompMessage {
|
||||
msg := &PubcompMessage{}
|
||||
msg.SetType(PUBCOMP)
|
||||
|
||||
return msg
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPubcompMessageFields(t *testing.T) {
|
||||
msg := NewPubcompMessage()
|
||||
|
||||
msg.SetPacketId(100)
|
||||
|
||||
require.Equal(t, 100, int(msg.PacketId()))
|
||||
}
|
||||
|
||||
func TestPubcompMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBCOMP << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubcompMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, PUBCOMP, msg.Type(), "Error decoding message.")
|
||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
||||
}
|
||||
|
||||
// test insufficient bytes
|
||||
func TestPubcompMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBCOMP << 4),
|
||||
2,
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubcompMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPubcompMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBCOMP << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubcompMessage()
|
||||
msg.SetPacketId(7)
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestPubcompDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBCOMP << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubcompMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,250 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// A PUBLISH Control Packet is sent from a Client to a Server or from Server to a Client
|
||||
// to transport an Application Message.
|
||||
type PublishMessage struct {
|
||||
header
|
||||
|
||||
topic []byte
|
||||
payload []byte
|
||||
}
|
||||
|
||||
var _ Message = (*PublishMessage)(nil)
|
||||
|
||||
// NewPublishMessage creates a new PUBLISH message.
|
||||
func NewPublishMessage() *PublishMessage {
|
||||
msg := &PublishMessage{}
|
||||
msg.SetType(PUBLISH)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func (this PublishMessage) String() string {
|
||||
return fmt.Sprintf("%s, Topic=%q, Packet ID=%d, QoS=%d, Retained=%t, Dup=%t, Payload=%v",
|
||||
this.header, this.topic, this.packetId, this.QoS(), this.Retain(), this.Dup(), this.payload)
|
||||
}
|
||||
|
||||
// Dup returns the value specifying the duplicate delivery of a PUBLISH Control Packet.
|
||||
// If the DUP flag is set to 0, it indicates that this is the first occasion that the
|
||||
// Client or Server has attempted to send this MQTT PUBLISH Packet. If the DUP flag is
|
||||
// set to 1, it indicates that this might be re-delivery of an earlier attempt to send
|
||||
// the Packet.
|
||||
func (this *PublishMessage) Dup() bool {
|
||||
return ((this.Flags() >> 3) & 0x1) == 1
|
||||
}
|
||||
|
||||
// SetDup sets the value specifying the duplicate delivery of a PUBLISH Control Packet.
|
||||
func (this *PublishMessage) SetDup(v bool) {
|
||||
if v {
|
||||
this.mtypeflags[0] |= 0x8 // 00001000
|
||||
} else {
|
||||
this.mtypeflags[0] &= 247 // 11110111
|
||||
}
|
||||
}
|
||||
|
||||
// Retain returns the value of the RETAIN flag. This flag is only used on the PUBLISH
|
||||
// Packet. If the RETAIN flag is set to 1, in a PUBLISH Packet sent by a Client to a
|
||||
// Server, the Server MUST store the Application Message and its QoS, so that it can be
|
||||
// delivered to future subscribers whose subscriptions match its topic name.
|
||||
func (this *PublishMessage) Retain() bool {
|
||||
return (this.Flags() & 0x1) == 1
|
||||
}
|
||||
|
||||
// SetRetain sets the value of the RETAIN flag.
|
||||
func (this *PublishMessage) SetRetain(v bool) {
|
||||
if v {
|
||||
this.mtypeflags[0] |= 0x1 // 00000001
|
||||
} else {
|
||||
this.mtypeflags[0] &= 254 // 11111110
|
||||
}
|
||||
}
|
||||
|
||||
// QoS returns the field that indicates the level of assurance for delivery of an
|
||||
// Application Message. The values are QosAtMostOnce, QosAtLeastOnce and QosExactlyOnce.
|
||||
func (this *PublishMessage) QoS() byte {
|
||||
return (this.Flags() >> 1) & 0x3
|
||||
}
|
||||
|
||||
// SetQoS sets the field that indicates the level of assurance for delivery of an
|
||||
// Application Message. The values are QosAtMostOnce, QosAtLeastOnce and QosExactlyOnce.
|
||||
// An error is returned if the value is not one of these.
|
||||
func (this *PublishMessage) SetQoS(v byte) error {
|
||||
if v != 0x0 && v != 0x1 && v != 0x2 {
|
||||
return fmt.Errorf("publish/SetQoS: Invalid QoS %d.", v)
|
||||
}
|
||||
|
||||
this.mtypeflags[0] = (this.mtypeflags[0] & 249) | (v << 1) // 249 = 11111001
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Topic returns the the topic name that identifies the information channel to which
|
||||
// payload data is published.
|
||||
func (this *PublishMessage) Topic() []byte {
|
||||
return this.topic
|
||||
}
|
||||
|
||||
// SetTopic sets the the topic name that identifies the information channel to which
|
||||
// payload data is published. An error is returned if ValidTopic() is falbase.
|
||||
func (this *PublishMessage) SetTopic(v []byte) error {
|
||||
if !ValidTopic(v) {
|
||||
return fmt.Errorf("publish/SetTopic: Invalid topic name (%s). Must not be empty or contain wildcard characters", string(v))
|
||||
}
|
||||
|
||||
this.topic = v
|
||||
this.dirty = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Payload returns the application message that's part of the PUBLISH message.
|
||||
func (this *PublishMessage) Payload() []byte {
|
||||
return this.payload
|
||||
}
|
||||
|
||||
// SetPayload sets the application message that's part of the PUBLISH message.
|
||||
func (this *PublishMessage) SetPayload(v []byte) {
|
||||
this.payload = v
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
func (this *PublishMessage) Len() int {
|
||||
if !this.dirty {
|
||||
return len(this.dbuf)
|
||||
}
|
||||
|
||||
ml := this.msglen()
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return this.header.msglen() + ml
|
||||
}
|
||||
|
||||
func (this *PublishMessage) Decode(src []byte) (int, error) {
|
||||
total := 0
|
||||
|
||||
hn, err := this.header.decode(src[total:])
|
||||
total += hn
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
n := 0
|
||||
|
||||
this.topic, n, err = readLPBytes(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
if !ValidTopic(this.topic) {
|
||||
return total, fmt.Errorf("publish/Decode: Invalid topic name (%s). Must not be empty or contain wildcard characters", string(this.topic))
|
||||
}
|
||||
|
||||
// The packet identifier field is only present in the PUBLISH packets where the
|
||||
// QoS level is 1 or 2
|
||||
if this.QoS() != 0 {
|
||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
||||
this.packetId = src[total : total+2]
|
||||
total += 2
|
||||
}
|
||||
|
||||
l := int(this.remlen) - (total - hn)
|
||||
this.payload = src[total : total+l]
|
||||
total += len(this.payload)
|
||||
|
||||
this.dirty = false
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *PublishMessage) Encode(dst []byte) (int, error) {
|
||||
if !this.dirty {
|
||||
if len(dst) < len(this.dbuf) {
|
||||
return 0, fmt.Errorf("publish/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
||||
}
|
||||
|
||||
return copy(dst, this.dbuf), nil
|
||||
}
|
||||
|
||||
if len(this.topic) == 0 {
|
||||
return 0, fmt.Errorf("publish/Encode: Topic name is empty.")
|
||||
}
|
||||
|
||||
if len(this.payload) == 0 {
|
||||
return 0, fmt.Errorf("publish/Encode: Payload is empty.")
|
||||
}
|
||||
|
||||
ml := this.msglen()
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
hl := this.header.msglen()
|
||||
|
||||
if len(dst) < hl+ml {
|
||||
return 0, fmt.Errorf("publish/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
||||
}
|
||||
|
||||
total := 0
|
||||
|
||||
n, err := this.header.encode(dst[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
n, err = writeLPBytes(dst[total:], this.topic)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
// The packet identifier field is only present in the PUBLISH packets where the QoS level is 1 or 2
|
||||
if this.QoS() != 0 {
|
||||
if this.PacketId() == 0 {
|
||||
this.SetPacketId(uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff))
|
||||
//this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff)
|
||||
}
|
||||
|
||||
n = copy(dst[total:], this.packetId)
|
||||
//binary.BigEndian.PutUint16(dst[total:], this.packetId)
|
||||
total += n
|
||||
}
|
||||
|
||||
copy(dst[total:], this.payload)
|
||||
total += len(this.payload)
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *PublishMessage) msglen() int {
|
||||
total := 2 + len(this.topic) + len(this.payload)
|
||||
if this.QoS() != 0 {
|
||||
total += 2
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
@@ -1,281 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPublishMessageHeaderFields(t *testing.T) {
|
||||
msg := NewPublishMessage()
|
||||
msg.mtypeflags[0] |= 11
|
||||
|
||||
require.True(t, msg.Dup(), "Incorrect DUP flag.")
|
||||
require.True(t, msg.Retain(), "Incorrect RETAIN flag.")
|
||||
require.Equal(t, 1, int(msg.QoS()), "Incorrect QoS.")
|
||||
|
||||
msg.SetDup(false)
|
||||
|
||||
require.False(t, msg.Dup(), "Incorrect DUP flag.")
|
||||
|
||||
msg.SetRetain(false)
|
||||
|
||||
require.False(t, msg.Retain(), "Incorrect RETAIN flag.")
|
||||
|
||||
err := msg.SetQoS(2)
|
||||
|
||||
require.NoError(t, err, "Error setting QoS.")
|
||||
require.Equal(t, 2, int(msg.QoS()), "Incorrect QoS.")
|
||||
|
||||
err = msg.SetQoS(3)
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
err = msg.SetQoS(0)
|
||||
|
||||
require.NoError(t, err, "Error setting QoS.")
|
||||
require.Equal(t, 0, int(msg.QoS()), "Incorrect QoS.")
|
||||
|
||||
msg.SetDup(true)
|
||||
|
||||
require.True(t, msg.Dup(), "Incorrect DUP flag.")
|
||||
|
||||
msg.SetRetain(true)
|
||||
|
||||
require.True(t, msg.Retain(), "Incorrect RETAIN flag.")
|
||||
}
|
||||
|
||||
func TestPublishMessageFields(t *testing.T) {
|
||||
msg := NewPublishMessage()
|
||||
|
||||
msg.SetTopic([]byte("coolstuff"))
|
||||
|
||||
require.Equal(t, "coolstuff", string(msg.Topic()), "Error setting message topic.")
|
||||
|
||||
err := msg.SetTopic([]byte("coolstuff/#"))
|
||||
|
||||
require.Error(t, err)
|
||||
|
||||
msg.SetPacketId(100)
|
||||
|
||||
require.Equal(t, 100, int(msg.PacketId()), "Error setting acket ID.")
|
||||
|
||||
msg.SetPayload([]byte("this is a payload to be sent"))
|
||||
|
||||
require.Equal(t, []byte("this is a payload to be sent"), msg.Payload(), "Error setting payload.")
|
||||
}
|
||||
|
||||
func TestPublishMessageDecode1(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBLISH<<4) | 2,
|
||||
23,
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
}
|
||||
|
||||
msg := NewPublishMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
||||
require.Equal(t, "surgemq", string(msg.Topic()), "Error deocding topic name.")
|
||||
require.Equal(t, []byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'}, msg.Payload(), "Error deocding payload.")
|
||||
}
|
||||
|
||||
// test insufficient bytes
|
||||
func TestPublishMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBLISH<<4) | 2,
|
||||
26,
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
}
|
||||
|
||||
msg := NewPublishMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// test qos = 0 and no client id
|
||||
func TestPublishMessageDecode3(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBLISH << 4),
|
||||
21,
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
}
|
||||
|
||||
msg := NewPublishMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
}
|
||||
|
||||
func TestPublishMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBLISH<<4) | 2,
|
||||
23,
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
}
|
||||
|
||||
msg := NewPublishMessage()
|
||||
msg.SetTopic([]byte("surgemq"))
|
||||
msg.SetQoS(1)
|
||||
msg.SetPacketId(7)
|
||||
msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'})
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test empty topic name
|
||||
func TestPublishMessageEncode2(t *testing.T) {
|
||||
msg := NewPublishMessage()
|
||||
msg.SetTopic([]byte(""))
|
||||
msg.SetPacketId(7)
|
||||
msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'})
|
||||
|
||||
dst := make([]byte, 100)
|
||||
_, err := msg.Encode(dst)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
// test encoding qos = 0 and no packet id
|
||||
func TestPublishMessageEncode3(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBLISH << 4),
|
||||
21,
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
}
|
||||
|
||||
msg := NewPublishMessage()
|
||||
msg.SetTopic([]byte("surgemq"))
|
||||
msg.SetQoS(0)
|
||||
msg.SetPayload([]byte{'s', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e'})
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test large message
|
||||
func TestPublishMessageEncode4(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBLISH << 4),
|
||||
137,
|
||||
8,
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
}
|
||||
|
||||
payload := make([]byte, 1024)
|
||||
msgBytes = append(msgBytes, payload...)
|
||||
|
||||
msg := NewPublishMessage()
|
||||
msg.SetTopic([]byte("surgemq"))
|
||||
msg.SetQoS(0)
|
||||
msg.SetPayload(payload)
|
||||
|
||||
require.Equal(t, len(msgBytes), msg.Len())
|
||||
|
||||
dst := make([]byte, 1100)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test from github issue #2, @mrdg
|
||||
func TestPublishDecodeEncodeEquiv2(t *testing.T) {
|
||||
msgBytes := []byte{50, 18, 0, 9, 103, 114, 101, 101, 116, 105, 110, 103, 115, 0, 1, 72, 101, 108, 108, 111}
|
||||
|
||||
msg := NewPublishMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestPublishDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBLISH<<4) | 2,
|
||||
23,
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
's', 'e', 'n', 'd', ' ', 'm', 'e', ' ', 'h', 'o', 'm', 'e',
|
||||
}
|
||||
|
||||
msg := NewPublishMessage()
|
||||
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
type PubrecMessage struct {
|
||||
PubackMessage
|
||||
}
|
||||
|
||||
// A PUBREC Packet is the response to a PUBLISH Packet with QoS 2. It is the second
|
||||
// packet of the QoS 2 protocol exchange.
|
||||
var _ Message = (*PubrecMessage)(nil)
|
||||
|
||||
// NewPubrecMessage creates a new PUBREC message.
|
||||
func NewPubrecMessage() *PubrecMessage {
|
||||
msg := &PubrecMessage{}
|
||||
msg.SetType(PUBREC)
|
||||
|
||||
return msg
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPubrecMessageFields(t *testing.T) {
|
||||
msg := NewPubrecMessage()
|
||||
|
||||
msg.SetPacketId(100)
|
||||
|
||||
require.Equal(t, 100, int(msg.PacketId()))
|
||||
}
|
||||
|
||||
func TestPubrecMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBREC << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubrecMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, PUBREC, msg.Type(), "Error decoding message.")
|
||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
||||
}
|
||||
|
||||
// test insufficient bytes
|
||||
func TestPubrecMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBREC << 4),
|
||||
2,
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubrecMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPubrecMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBREC << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubrecMessage()
|
||||
msg.SetPacketId(7)
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestPubrecDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBREC << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubrecMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
// A PUBREL Packet is the response to a PUBREC Packet. It is the third packet of the
|
||||
// QoS 2 protocol exchange.
|
||||
type PubrelMessage struct {
|
||||
PubackMessage
|
||||
}
|
||||
|
||||
var _ Message = (*PubrelMessage)(nil)
|
||||
|
||||
// NewPubrelMessage creates a new PUBREL message.
|
||||
func NewPubrelMessage() *PubrelMessage {
|
||||
msg := &PubrelMessage{}
|
||||
msg.SetType(PUBREL)
|
||||
|
||||
return msg
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPubrelMessageFields(t *testing.T) {
|
||||
msg := NewPubrelMessage()
|
||||
|
||||
msg.SetPacketId(100)
|
||||
|
||||
require.Equal(t, 100, int(msg.PacketId()))
|
||||
}
|
||||
|
||||
func TestPubrelMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBREL<<4) | 2,
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubrelMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, PUBREL, msg.Type(), "Error decoding message.")
|
||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
||||
}
|
||||
|
||||
// test insufficient bytes
|
||||
func TestPubrelMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBREL<<4) | 2,
|
||||
2,
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubrelMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestPubrelMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBREL<<4) | 2,
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubrelMessage()
|
||||
msg.SetPacketId(7)
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestPubrelDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(PUBREL<<4) | 2,
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewPubrelMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,160 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import "fmt"
|
||||
|
||||
// A SUBACK Packet is sent by the Server to the Client to confirm receipt and processing
|
||||
// of a SUBSCRIBE Packet.
|
||||
//
|
||||
// A SUBACK Packet contains a list of return codes, that specify the maximum QoS level
|
||||
// that was granted in each Subscription that was requested by the SUBSCRIBE.
|
||||
type SubackMessage struct {
|
||||
header
|
||||
|
||||
returnCodes []byte
|
||||
}
|
||||
|
||||
var _ Message = (*SubackMessage)(nil)
|
||||
|
||||
// NewSubackMessage creates a new SUBACK message.
|
||||
func NewSubackMessage() *SubackMessage {
|
||||
msg := &SubackMessage{}
|
||||
msg.SetType(SUBACK)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
// String returns a string representation of the message.
|
||||
func (this SubackMessage) String() string {
|
||||
return fmt.Sprintf("%s, Packet ID=%d, Return Codes=%v", this.header, this.PacketId(), this.returnCodes)
|
||||
}
|
||||
|
||||
// ReturnCodes returns the list of QoS returns from the subscriptions sent in the SUBSCRIBE message.
|
||||
func (this *SubackMessage) ReturnCodes() []byte {
|
||||
return this.returnCodes
|
||||
}
|
||||
|
||||
// AddReturnCodes sets the list of QoS returns from the subscriptions sent in the SUBSCRIBE message.
|
||||
// An error is returned if any of the QoS values are not valid.
|
||||
func (this *SubackMessage) AddReturnCodes(ret []byte) error {
|
||||
for _, c := range ret {
|
||||
if c != QosAtMostOnce && c != QosAtLeastOnce && c != QosExactlyOnce && c != QosFailure {
|
||||
return fmt.Errorf("suback/AddReturnCode: Invalid return code %d. Must be 0, 1, 2, 0x80.", c)
|
||||
}
|
||||
|
||||
this.returnCodes = append(this.returnCodes, c)
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddReturnCode adds a single QoS return value.
|
||||
func (this *SubackMessage) AddReturnCode(ret byte) error {
|
||||
return this.AddReturnCodes([]byte{ret})
|
||||
}
|
||||
|
||||
func (this *SubackMessage) Len() int {
|
||||
if !this.dirty {
|
||||
return len(this.dbuf)
|
||||
}
|
||||
|
||||
ml := this.msglen()
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return this.header.msglen() + ml
|
||||
}
|
||||
|
||||
func (this *SubackMessage) Decode(src []byte) (int, error) {
|
||||
total := 0
|
||||
|
||||
hn, err := this.header.decode(src[total:])
|
||||
total += hn
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
||||
this.packetId = src[total : total+2]
|
||||
total += 2
|
||||
|
||||
l := int(this.remlen) - (total - hn)
|
||||
this.returnCodes = src[total : total+l]
|
||||
total += len(this.returnCodes)
|
||||
|
||||
for i, code := range this.returnCodes {
|
||||
if code != 0x00 && code != 0x01 && code != 0x02 && code != 0x80 {
|
||||
return total, fmt.Errorf("suback/Decode: Invalid return code %d for topic %d", code, i)
|
||||
}
|
||||
}
|
||||
|
||||
this.dirty = false
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *SubackMessage) Encode(dst []byte) (int, error) {
|
||||
if !this.dirty {
|
||||
if len(dst) < len(this.dbuf) {
|
||||
return 0, fmt.Errorf("suback/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
||||
}
|
||||
|
||||
return copy(dst, this.dbuf), nil
|
||||
}
|
||||
|
||||
for i, code := range this.returnCodes {
|
||||
if code != 0x00 && code != 0x01 && code != 0x02 && code != 0x80 {
|
||||
return 0, fmt.Errorf("suback/Encode: Invalid return code %d for topic %d", code, i)
|
||||
}
|
||||
}
|
||||
|
||||
hl := this.header.msglen()
|
||||
ml := this.msglen()
|
||||
|
||||
if len(dst) < hl+ml {
|
||||
return 0, fmt.Errorf("suback/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
||||
}
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total := 0
|
||||
|
||||
n, err := this.header.encode(dst[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
if copy(dst[total:total+2], this.packetId) != 2 {
|
||||
dst[total], dst[total+1] = 0, 0
|
||||
}
|
||||
total += 2
|
||||
|
||||
copy(dst[total:], this.returnCodes)
|
||||
total += len(this.returnCodes)
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *SubackMessage) msglen() int {
|
||||
return 2 + len(this.returnCodes)
|
||||
}
|
||||
@@ -1,134 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSubackMessageFields(t *testing.T) {
|
||||
msg := NewSubackMessage()
|
||||
|
||||
msg.SetPacketId(100)
|
||||
require.Equal(t, 100, int(msg.PacketId()), "Error setting packet ID.")
|
||||
|
||||
msg.AddReturnCode(1)
|
||||
require.Equal(t, 1, len(msg.ReturnCodes()), "Error adding return code.")
|
||||
|
||||
err := msg.AddReturnCode(0x90)
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSubackMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(SUBACK << 4),
|
||||
6,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // return code 1
|
||||
1, // return code 2
|
||||
2, // return code 3
|
||||
0x80, // return code 4
|
||||
}
|
||||
|
||||
msg := NewSubackMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, SUBACK, msg.Type(), "Error decoding message.")
|
||||
require.Equal(t, 4, len(msg.ReturnCodes()), "Error adding return code.")
|
||||
}
|
||||
|
||||
// test with wrong return code
|
||||
func TestSubackMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(SUBACK << 4),
|
||||
6,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // return code 1
|
||||
1, // return code 2
|
||||
2, // return code 3
|
||||
0x81, // return code 4
|
||||
}
|
||||
|
||||
msg := NewSubackMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSubackMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(SUBACK << 4),
|
||||
6,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // return code 1
|
||||
1, // return code 2
|
||||
2, // return code 3
|
||||
0x80, // return code 4
|
||||
}
|
||||
|
||||
msg := NewSubackMessage()
|
||||
msg.SetPacketId(7)
|
||||
msg.AddReturnCode(0)
|
||||
msg.AddReturnCode(1)
|
||||
msg.AddReturnCode(2)
|
||||
msg.AddReturnCode(0x80)
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestSubackDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(SUBACK << 4),
|
||||
6,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // return code 1
|
||||
1, // return code 2
|
||||
2, // return code 3
|
||||
0x80, // return code 4
|
||||
}
|
||||
|
||||
msg := NewSubackMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,253 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// The SUBSCRIBE Packet is sent from the Client to the Server to create one or more
|
||||
// Subscriptions. Each Subscription registers a Client’s interest in one or more
|
||||
// Topics. The Server sends PUBLISH Packets to the Client in order to forward
|
||||
// Application Messages that were published to Topics that match these Subscriptions.
|
||||
// The SUBSCRIBE Packet also specifies (for each Subscription) the maximum QoS with
|
||||
// which the Server can send Application Messages to the Client.
|
||||
type SubscribeMessage struct {
|
||||
header
|
||||
|
||||
topics [][]byte
|
||||
qos []byte
|
||||
}
|
||||
|
||||
var _ Message = (*SubscribeMessage)(nil)
|
||||
|
||||
// NewSubscribeMessage creates a new SUBSCRIBE message.
|
||||
func NewSubscribeMessage() *SubscribeMessage {
|
||||
msg := &SubscribeMessage{}
|
||||
msg.SetType(SUBSCRIBE)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func (this SubscribeMessage) String() string {
|
||||
msgstr := fmt.Sprintf("%s, Packet ID=%d", this.header, this.PacketId())
|
||||
|
||||
for i, t := range this.topics {
|
||||
msgstr = fmt.Sprintf("%s, Topic[%d]=%q/%d", msgstr, i, string(t), this.qos[i])
|
||||
}
|
||||
|
||||
return msgstr
|
||||
}
|
||||
|
||||
// Topics returns a list of topics sent by the Client.
|
||||
func (this *SubscribeMessage) Topics() [][]byte {
|
||||
return this.topics
|
||||
}
|
||||
|
||||
// AddTopic adds a single topic to the message, along with the corresponding QoS.
|
||||
// An error is returned if QoS is invalid.
|
||||
func (this *SubscribeMessage) AddTopic(topic []byte, qos byte) error {
|
||||
if !ValidQos(qos) {
|
||||
return fmt.Errorf("Invalid QoS %d", qos)
|
||||
}
|
||||
|
||||
var i int
|
||||
var t []byte
|
||||
var found bool
|
||||
|
||||
for i, t = range this.topics {
|
||||
if bytes.Equal(t, topic) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
this.qos[i] = qos
|
||||
return nil
|
||||
}
|
||||
|
||||
this.topics = append(this.topics, topic)
|
||||
this.qos = append(this.qos, qos)
|
||||
this.dirty = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveTopic removes a single topic from the list of existing ones in the message.
|
||||
// If topic does not exist it just does nothing.
|
||||
func (this *SubscribeMessage) RemoveTopic(topic []byte) {
|
||||
var i int
|
||||
var t []byte
|
||||
var found bool
|
||||
|
||||
for i, t = range this.topics {
|
||||
if bytes.Equal(t, topic) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
this.topics = append(this.topics[:i], this.topics[i+1:]...)
|
||||
this.qos = append(this.qos[:i], this.qos[i+1:]...)
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// TopicExists checks to see if a topic exists in the list.
|
||||
func (this *SubscribeMessage) TopicExists(topic []byte) bool {
|
||||
for _, t := range this.topics {
|
||||
if bytes.Equal(t, topic) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// TopicQos returns the QoS level of a topic. If topic does not exist, QosFailure
|
||||
// is returned.
|
||||
func (this *SubscribeMessage) TopicQos(topic []byte) byte {
|
||||
for i, t := range this.topics {
|
||||
if bytes.Equal(t, topic) {
|
||||
return this.qos[i]
|
||||
}
|
||||
}
|
||||
|
||||
return QosFailure
|
||||
}
|
||||
|
||||
// Qos returns the list of QoS current in the message.
|
||||
func (this *SubscribeMessage) Qos() []byte {
|
||||
return this.qos
|
||||
}
|
||||
|
||||
func (this *SubscribeMessage) Len() int {
|
||||
if !this.dirty {
|
||||
return len(this.dbuf)
|
||||
}
|
||||
|
||||
ml := this.msglen()
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return this.header.msglen() + ml
|
||||
}
|
||||
|
||||
func (this *SubscribeMessage) Decode(src []byte) (int, error) {
|
||||
total := 0
|
||||
|
||||
hn, err := this.header.decode(src[total:])
|
||||
total += hn
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
||||
this.packetId = src[total : total+2]
|
||||
total += 2
|
||||
|
||||
remlen := int(this.remlen) - (total - hn)
|
||||
for remlen > 0 {
|
||||
t, n, err := readLPBytes(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
this.topics = append(this.topics, t)
|
||||
|
||||
this.qos = append(this.qos, src[total])
|
||||
total++
|
||||
|
||||
remlen = remlen - n - 1
|
||||
}
|
||||
|
||||
if len(this.topics) == 0 {
|
||||
return 0, fmt.Errorf("subscribe/Decode: Empty topic list")
|
||||
}
|
||||
|
||||
this.dirty = false
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *SubscribeMessage) Encode(dst []byte) (int, error) {
|
||||
if !this.dirty {
|
||||
if len(dst) < len(this.dbuf) {
|
||||
return 0, fmt.Errorf("subscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
||||
}
|
||||
|
||||
return copy(dst, this.dbuf), nil
|
||||
}
|
||||
|
||||
hl := this.header.msglen()
|
||||
ml := this.msglen()
|
||||
|
||||
if len(dst) < hl+ml {
|
||||
return 0, fmt.Errorf("subscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
||||
}
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total := 0
|
||||
|
||||
n, err := this.header.encode(dst[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
if this.PacketId() == 0 {
|
||||
this.SetPacketId(uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff))
|
||||
//this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff)
|
||||
}
|
||||
|
||||
n = copy(dst[total:], this.packetId)
|
||||
//binary.BigEndian.PutUint16(dst[total:], this.packetId)
|
||||
total += n
|
||||
|
||||
for i, t := range this.topics {
|
||||
n, err := writeLPBytes(dst[total:], t)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
dst[total] = this.qos[i]
|
||||
total++
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *SubscribeMessage) msglen() int {
|
||||
// packet ID
|
||||
total := 2
|
||||
|
||||
for _, t := range this.topics {
|
||||
total += 2 + len(t) + 1
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
@@ -1,161 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestSubscribeMessageFields(t *testing.T) {
|
||||
msg := NewSubscribeMessage()
|
||||
|
||||
msg.SetPacketId(100)
|
||||
require.Equal(t, 100, int(msg.PacketId()), "Error setting packet ID.")
|
||||
|
||||
msg.AddTopic([]byte("/a/b/#/c"), 1)
|
||||
require.Equal(t, 1, len(msg.Topics()), "Error adding topic.")
|
||||
|
||||
require.False(t, msg.TopicExists([]byte("a/b")), "Topic should not exist.")
|
||||
|
||||
msg.RemoveTopic([]byte("/a/b/#/c"))
|
||||
require.False(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.")
|
||||
}
|
||||
|
||||
func TestSubscribeMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(SUBSCRIBE<<4) | 2,
|
||||
36,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // QoS
|
||||
0, // topic name MSB (0)
|
||||
8, // topic name LSB (8)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
||||
1, // QoS
|
||||
0, // topic name MSB (0)
|
||||
10, // topic name LSB (10)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
||||
2, // QoS
|
||||
}
|
||||
|
||||
msg := NewSubscribeMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, SUBSCRIBE, msg.Type(), "Error decoding message.")
|
||||
require.Equal(t, 3, len(msg.Topics()), "Error decoding topics.")
|
||||
require.True(t, msg.TopicExists([]byte("surgemq")), "Topic 'surgemq' should exist.")
|
||||
require.Equal(t, 0, int(msg.TopicQos([]byte("surgemq"))), "Incorrect topic qos.")
|
||||
require.True(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic '/a/b/#/c' should exist.")
|
||||
require.Equal(t, 1, int(msg.TopicQos([]byte("/a/b/#/c"))), "Incorrect topic qos.")
|
||||
require.True(t, msg.TopicExists([]byte("/a/b/#/cdd")), "Topic '/a/b/#/c' should exist.")
|
||||
require.Equal(t, 2, int(msg.TopicQos([]byte("/a/b/#/cdd"))), "Incorrect topic qos.")
|
||||
}
|
||||
|
||||
// test empty topic list
|
||||
func TestSubscribeMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(SUBSCRIBE<<4) | 2,
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewSubscribeMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestSubscribeMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(SUBSCRIBE<<4) | 2,
|
||||
36,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // QoS
|
||||
0, // topic name MSB (0)
|
||||
8, // topic name LSB (8)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
||||
1, // QoS
|
||||
0, // topic name MSB (0)
|
||||
10, // topic name LSB (10)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
||||
2, // QoS
|
||||
}
|
||||
|
||||
msg := NewSubscribeMessage()
|
||||
msg.SetPacketId(7)
|
||||
msg.AddTopic([]byte("surgemq"), 0)
|
||||
msg.AddTopic([]byte("/a/b/#/c"), 1)
|
||||
msg.AddTopic([]byte("/a/b/#/cdd"), 2)
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestSubscribeDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(SUBSCRIBE<<4) | 2,
|
||||
36,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // QoS
|
||||
0, // topic name MSB (0)
|
||||
8, // topic name LSB (8)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
||||
1, // QoS
|
||||
0, // topic name MSB (0)
|
||||
10, // topic name LSB (10)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
||||
2, // QoS
|
||||
}
|
||||
|
||||
msg := NewSubscribeMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
// The UNSUBACK Packet is sent by the Server to the Client to confirm receipt of an
|
||||
// UNSUBSCRIBE Packet.
|
||||
type UnsubackMessage struct {
|
||||
PubackMessage
|
||||
}
|
||||
|
||||
var _ Message = (*UnsubackMessage)(nil)
|
||||
|
||||
// NewUnsubackMessage creates a new UNSUBACK message.
|
||||
func NewUnsubackMessage() *UnsubackMessage {
|
||||
msg := &UnsubackMessage{}
|
||||
msg.SetType(UNSUBACK)
|
||||
|
||||
return msg
|
||||
}
|
||||
@@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUnsubackMessageFields(t *testing.T) {
|
||||
msg := NewUnsubackMessage()
|
||||
|
||||
msg.SetPacketId(100)
|
||||
|
||||
require.Equal(t, 100, int(msg.PacketId()))
|
||||
}
|
||||
|
||||
func TestUnsubackMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(UNSUBACK << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewUnsubackMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, UNSUBACK, msg.Type(), "Error decoding message.")
|
||||
require.Equal(t, 7, int(msg.PacketId()), "Error decoding message.")
|
||||
}
|
||||
|
||||
// test insufficient bytes
|
||||
func TestUnsubackMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(UNSUBACK << 4),
|
||||
2,
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewUnsubackMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUnsubackMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(UNSUBACK << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewUnsubackMessage()
|
||||
msg.SetPacketId(7)
|
||||
|
||||
dst := make([]byte, 10)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestUnsubackDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(UNSUBACK << 4),
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewUnsubackMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
@@ -1,210 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// An UNSUBSCRIBE Packet is sent by the Client to the Server, to unsubscribe from topics.
|
||||
type UnsubscribeMessage struct {
|
||||
header
|
||||
|
||||
topics [][]byte
|
||||
}
|
||||
|
||||
var _ Message = (*UnsubscribeMessage)(nil)
|
||||
|
||||
// NewUnsubscribeMessage creates a new UNSUBSCRIBE message.
|
||||
func NewUnsubscribeMessage() *UnsubscribeMessage {
|
||||
msg := &UnsubscribeMessage{}
|
||||
msg.SetType(UNSUBSCRIBE)
|
||||
|
||||
return msg
|
||||
}
|
||||
|
||||
func (this UnsubscribeMessage) String() string {
|
||||
msgstr := fmt.Sprintf("%s", this.header)
|
||||
|
||||
for i, t := range this.topics {
|
||||
msgstr = fmt.Sprintf("%s, Topic%d=%s", msgstr, i, string(t))
|
||||
}
|
||||
|
||||
return msgstr
|
||||
}
|
||||
|
||||
// Topics returns a list of topics sent by the Client.
|
||||
func (this *UnsubscribeMessage) Topics() [][]byte {
|
||||
return this.topics
|
||||
}
|
||||
|
||||
// AddTopic adds a single topic to the message.
|
||||
func (this *UnsubscribeMessage) AddTopic(topic []byte) {
|
||||
if this.TopicExists(topic) {
|
||||
return
|
||||
}
|
||||
|
||||
this.topics = append(this.topics, topic)
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// RemoveTopic removes a single topic from the list of existing ones in the message.
|
||||
// If topic does not exist it just does nothing.
|
||||
func (this *UnsubscribeMessage) RemoveTopic(topic []byte) {
|
||||
var i int
|
||||
var t []byte
|
||||
var found bool
|
||||
|
||||
for i, t = range this.topics {
|
||||
if bytes.Equal(t, topic) {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if found {
|
||||
this.topics = append(this.topics[:i], this.topics[i+1:]...)
|
||||
}
|
||||
|
||||
this.dirty = true
|
||||
}
|
||||
|
||||
// TopicExists checks to see if a topic exists in the list.
|
||||
func (this *UnsubscribeMessage) TopicExists(topic []byte) bool {
|
||||
for _, t := range this.topics {
|
||||
if bytes.Equal(t, topic) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func (this *UnsubscribeMessage) Len() int {
|
||||
if !this.dirty {
|
||||
return len(this.dbuf)
|
||||
}
|
||||
|
||||
ml := this.msglen()
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
return this.header.msglen() + ml
|
||||
}
|
||||
|
||||
// Decode reads from the io.Reader parameter until a full message is decoded, or
|
||||
// when io.Reader returns EOF or error. The first return value is the number of
|
||||
// bytes read from io.Reader. The second is error if Decode encounters any problems.
|
||||
func (this *UnsubscribeMessage) Decode(src []byte) (int, error) {
|
||||
total := 0
|
||||
|
||||
hn, err := this.header.decode(src[total:])
|
||||
total += hn
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
//this.packetId = binary.BigEndian.Uint16(src[total:])
|
||||
this.packetId = src[total : total+2]
|
||||
total += 2
|
||||
|
||||
remlen := int(this.remlen) - (total - hn)
|
||||
for remlen > 0 {
|
||||
t, n, err := readLPBytes(src[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
this.topics = append(this.topics, t)
|
||||
remlen = remlen - n - 1
|
||||
}
|
||||
|
||||
if len(this.topics) == 0 {
|
||||
return 0, fmt.Errorf("unsubscribe/Decode: Empty topic list")
|
||||
}
|
||||
|
||||
this.dirty = false
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// Encode returns an io.Reader in which the encoded bytes can be read. The second
|
||||
// return value is the number of bytes encoded, so the caller knows how many bytes
|
||||
// there will be. If Encode returns an error, then the first two return values
|
||||
// should be considered invalid.
|
||||
// Any changes to the message after Encode() is called will invalidate the io.Reader.
|
||||
func (this *UnsubscribeMessage) Encode(dst []byte) (int, error) {
|
||||
if !this.dirty {
|
||||
if len(dst) < len(this.dbuf) {
|
||||
return 0, fmt.Errorf("unsubscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", len(this.dbuf), len(dst))
|
||||
}
|
||||
|
||||
return copy(dst, this.dbuf), nil
|
||||
}
|
||||
|
||||
hl := this.header.msglen()
|
||||
ml := this.msglen()
|
||||
|
||||
if len(dst) < hl+ml {
|
||||
return 0, fmt.Errorf("unsubscribe/Encode: Insufficient buffer size. Expecting %d, got %d.", hl+ml, len(dst))
|
||||
}
|
||||
|
||||
if err := this.SetRemainingLength(int32(ml)); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
total := 0
|
||||
|
||||
n, err := this.header.encode(dst[total:])
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
|
||||
if this.PacketId() == 0 {
|
||||
this.SetPacketId(uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff))
|
||||
//this.packetId = uint16(atomic.AddUint64(&gPacketId, 1) & 0xffff)
|
||||
}
|
||||
|
||||
n = copy(dst[total:], this.packetId)
|
||||
//binary.BigEndian.PutUint16(dst[total:], this.packetId)
|
||||
total += n
|
||||
|
||||
for _, t := range this.topics {
|
||||
n, err := writeLPBytes(dst[total:], t)
|
||||
total += n
|
||||
if err != nil {
|
||||
return total, err
|
||||
}
|
||||
}
|
||||
|
||||
return total, nil
|
||||
}
|
||||
|
||||
func (this *UnsubscribeMessage) msglen() int {
|
||||
// packet ID
|
||||
total := 2
|
||||
|
||||
for _, t := range this.topics {
|
||||
total += 2 + len(t)
|
||||
}
|
||||
|
||||
return total
|
||||
}
|
||||
@@ -1,155 +0,0 @@
|
||||
// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package message
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUnsubscribeMessageFields(t *testing.T) {
|
||||
msg := NewUnsubscribeMessage()
|
||||
|
||||
msg.SetPacketId(100)
|
||||
require.Equal(t, 100, int(msg.PacketId()), "Error setting packet ID.")
|
||||
|
||||
msg.AddTopic([]byte("/a/b/#/c"))
|
||||
require.Equal(t, 1, len(msg.Topics()), "Error adding topic.")
|
||||
|
||||
msg.AddTopic([]byte("/a/b/#/c"))
|
||||
require.Equal(t, 1, len(msg.Topics()), "Error adding duplicate topic.")
|
||||
|
||||
msg.RemoveTopic([]byte("/a/b/#/c"))
|
||||
require.False(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.")
|
||||
|
||||
require.False(t, msg.TopicExists([]byte("a/b")), "Topic should not exist.")
|
||||
|
||||
msg.RemoveTopic([]byte("/a/b/#/c"))
|
||||
require.False(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic should not exist.")
|
||||
}
|
||||
|
||||
func TestUnsubscribeMessageDecode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(UNSUBSCRIBE<<4) | 2,
|
||||
33,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // topic name MSB (0)
|
||||
8, // topic name LSB (8)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
||||
0, // topic name MSB (0)
|
||||
10, // topic name LSB (10)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
||||
}
|
||||
|
||||
msg := NewUnsubscribeMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, UNSUBSCRIBE, msg.Type(), "Error decoding message.")
|
||||
require.Equal(t, 3, len(msg.Topics()), "Error decoding topics.")
|
||||
require.True(t, msg.TopicExists([]byte("surgemq")), "Topic 'surgemq' should exist.")
|
||||
require.True(t, msg.TopicExists([]byte("/a/b/#/c")), "Topic '/a/b/#/c' should exist.")
|
||||
require.True(t, msg.TopicExists([]byte("/a/b/#/cdd")), "Topic '/a/b/#/c' should exist.")
|
||||
}
|
||||
|
||||
// test empty topic list
|
||||
func TestUnsubscribeMessageDecode2(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(UNSUBSCRIBE<<4) | 2,
|
||||
2,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
}
|
||||
|
||||
msg := NewUnsubscribeMessage()
|
||||
_, err := msg.Decode(msgBytes)
|
||||
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestUnsubscribeMessageEncode(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(UNSUBSCRIBE<<4) | 2,
|
||||
33,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // topic name MSB (0)
|
||||
8, // topic name LSB (8)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
||||
0, // topic name MSB (0)
|
||||
10, // topic name LSB (10)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
||||
}
|
||||
|
||||
msg := NewUnsubscribeMessage()
|
||||
msg.SetPacketId(7)
|
||||
msg.AddTopic([]byte("surgemq"))
|
||||
msg.AddTopic([]byte("/a/b/#/c"))
|
||||
msg.AddTopic([]byte("/a/b/#/cdd"))
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n], "Error decoding message.")
|
||||
}
|
||||
|
||||
// test to ensure encoding and decoding are the same
|
||||
// decode, encode, and decode again
|
||||
func TestUnsubscribeDecodeEncodeEquiv(t *testing.T) {
|
||||
msgBytes := []byte{
|
||||
byte(UNSUBSCRIBE<<4) | 2,
|
||||
33,
|
||||
0, // packet ID MSB (0)
|
||||
7, // packet ID LSB (7)
|
||||
0, // topic name MSB (0)
|
||||
7, // topic name LSB (7)
|
||||
's', 'u', 'r', 'g', 'e', 'm', 'q',
|
||||
0, // topic name MSB (0)
|
||||
8, // topic name LSB (8)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c',
|
||||
0, // topic name MSB (0)
|
||||
10, // topic name LSB (10)
|
||||
'/', 'a', '/', 'b', '/', '#', '/', 'c', 'd', 'd',
|
||||
}
|
||||
|
||||
msg := NewUnsubscribeMessage()
|
||||
n, err := msg.Decode(msgBytes)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n, "Error decoding message.")
|
||||
|
||||
dst := make([]byte, 100)
|
||||
n2, err := msg.Encode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n2, "Error decoding message.")
|
||||
require.Equal(t, msgBytes, dst[:n2], "Error decoding message.")
|
||||
|
||||
n3, err := msg.Decode(dst)
|
||||
|
||||
require.NoError(t, err, "Error decoding message.")
|
||||
require.Equal(t, len(msgBytes), n3, "Error decoding message.")
|
||||
}
|
||||
64
logger/logger.go
Normal file
64
logger/logger.go
Normal file
@@ -0,0 +1,64 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
|
||||
package logger
|
||||
|
||||
import (
|
||||
"go.uber.org/zap"
|
||||
"go.uber.org/zap/zapcore"
|
||||
)
|
||||
|
||||
var (
|
||||
// env can be setup at build time with Go Linker. Value could be prod or whatever else for dev env
|
||||
instance *zap.Logger
|
||||
logCfg zap.Config
|
||||
encoderCfg = zap.NewProductionEncoderConfig()
|
||||
)
|
||||
|
||||
func init() {
|
||||
encoderCfg.TimeKey = "timestamp"
|
||||
encoderCfg.EncodeTime = zapcore.ISO8601TimeEncoder
|
||||
}
|
||||
|
||||
// NewDevLogger return a logger for dev builds
|
||||
func NewDevLogger() (*zap.Logger, error) {
|
||||
logCfg := zap.NewProductionConfig()
|
||||
logCfg.Level = zap.NewAtomicLevelAt(zap.DebugLevel)
|
||||
// logCfg.DisableStacktrace = true
|
||||
logCfg.EncoderConfig = encoderCfg
|
||||
return logCfg.Build()
|
||||
}
|
||||
|
||||
// NewProdLogger return a logger for production builds
|
||||
func NewProdLogger() (*zap.Logger, error) {
|
||||
logCfg := zap.NewProductionConfig()
|
||||
logCfg.DisableStacktrace = true
|
||||
logCfg.Level = zap.NewAtomicLevelAt(zap.InfoLevel)
|
||||
logCfg.EncoderConfig = encoderCfg
|
||||
return logCfg.Build()
|
||||
}
|
||||
|
||||
func Prod() *zap.Logger {
|
||||
|
||||
l, _ := NewProdLogger()
|
||||
instance = l
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func Debug() *zap.Logger {
|
||||
|
||||
l, _ := NewDevLogger()
|
||||
instance = l
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
func Get() *zap.Logger {
|
||||
if instance == nil {
|
||||
l, _ := NewProdLogger()
|
||||
instance = l
|
||||
}
|
||||
|
||||
return instance
|
||||
}
|
||||
33
logger/logger_test.go
Normal file
33
logger/logger_test.go
Normal file
@@ -0,0 +1,33 @@
|
||||
/*
|
||||
Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
*/
|
||||
package logger
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
func TestGet(t *testing.T) {
|
||||
var l *zap.Logger
|
||||
logger := Get()
|
||||
|
||||
assert.NotNil(t, logger)
|
||||
assert.IsType(t, l, logger)
|
||||
}
|
||||
|
||||
func TestNewDevLogger(t *testing.T) {
|
||||
logger, err := NewDevLogger()
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.True(t, logger.Core().Enabled(zap.DebugLevel))
|
||||
}
|
||||
|
||||
func TestNewProdLogger(t *testing.T) {
|
||||
logger, err := NewProdLogger()
|
||||
|
||||
assert.Nil(t, err)
|
||||
assert.False(t, logger.Core().Enabled(zap.DebugLevel))
|
||||
}
|
||||
28
main.go
28
main.go
@@ -1,29 +1,35 @@
|
||||
/* Copyright (c) 2018, joy.zhou <chowyu08@gmail.com>
|
||||
|
||||
Permission to use, copy, modify, and/or distribute this software for any
|
||||
purpose with or without fee is hereby granted, provided that the above
|
||||
copyright notice and this permission notice appear in all copies.
|
||||
*/
|
||||
package main
|
||||
|
||||
import (
|
||||
"hmq/broker"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"runtime"
|
||||
|
||||
log "github.com/cihub/seelog"
|
||||
"github.com/fhmq/hmq/broker"
|
||||
)
|
||||
|
||||
func main() {
|
||||
config, er := broker.LoadConfig()
|
||||
if er != nil {
|
||||
log.Error("Load Config file error: ", er)
|
||||
return
|
||||
runtime.GOMAXPROCS(runtime.NumCPU())
|
||||
config, err := broker.ConfigureConfig(os.Args[1:])
|
||||
if err != nil {
|
||||
log.Fatal("configure broker config error: ", err)
|
||||
}
|
||||
|
||||
broker, err := broker.NewBroker(config)
|
||||
b, err := broker.NewBroker(config)
|
||||
if err != nil {
|
||||
log.Error("New Broker error: ", er)
|
||||
return
|
||||
log.Fatal("New Broker error: ", err)
|
||||
}
|
||||
broker.Start()
|
||||
b.Start()
|
||||
|
||||
s := waitForSignal()
|
||||
log.Infof("signal got: %v ,broker closed.", s)
|
||||
log.Println("signal received, broker closed.", s)
|
||||
}
|
||||
|
||||
func waitForSignal() os.Signal {
|
||||
|
||||
23
plugins/auth/auth.go
Normal file
23
plugins/auth/auth.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"github.com/fhmq/hmq/plugins/auth/authhttp"
|
||||
)
|
||||
|
||||
const (
|
||||
AuthHTTP = "authhttp"
|
||||
)
|
||||
|
||||
type Auth interface {
|
||||
CheckACL(action, username, topic string) bool
|
||||
CheckConnect(clientID, username, password string) bool
|
||||
}
|
||||
|
||||
func NewAuth(name string) Auth {
|
||||
switch name {
|
||||
case AuthHTTP:
|
||||
return authhttp.Init()
|
||||
default:
|
||||
return &mockAuth{}
|
||||
}
|
||||
}
|
||||
179
plugins/auth/authhttp/authhttp.go
Normal file
179
plugins/auth/authhttp/authhttp.go
Normal file
@@ -0,0 +1,179 @@
|
||||
package authhttp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fhmq/hmq/logger"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
//Config device kafka config
|
||||
type Config struct {
|
||||
AuthURL string `json:"auth"`
|
||||
ACLURL string `json:"acl"`
|
||||
SuperURL string `json:"super"`
|
||||
}
|
||||
|
||||
type authHTTP struct {
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
var (
|
||||
config Config
|
||||
log = logger.Get().Named("authhttp")
|
||||
httpClient *http.Client
|
||||
)
|
||||
|
||||
//Init init kafak client
|
||||
func Init() *authHTTP {
|
||||
content, err := ioutil.ReadFile("./plugins/auth/authhttp/http.json")
|
||||
if err != nil {
|
||||
log.Fatal("Read config file error: ", zap.Error(err))
|
||||
}
|
||||
// log.Info(string(content))
|
||||
|
||||
err = json.Unmarshal(content, &config)
|
||||
if err != nil {
|
||||
log.Fatal("Unmarshal config file error: ", zap.Error(err))
|
||||
}
|
||||
// fmt.Println("http: config: ", config)
|
||||
|
||||
httpClient = &http.Client{
|
||||
Transport: &http.Transport{
|
||||
MaxConnsPerHost: 100,
|
||||
MaxIdleConns: 100,
|
||||
MaxIdleConnsPerHost: 100,
|
||||
},
|
||||
Timeout: time.Second * 100,
|
||||
}
|
||||
return &authHTTP{client: httpClient}
|
||||
}
|
||||
|
||||
//CheckAuth check mqtt connect
|
||||
func (a *authHTTP) CheckConnect(clientID, username, password string) bool {
|
||||
action := "connect"
|
||||
{
|
||||
aCache := checkCache(action, clientID, username, password, "")
|
||||
if aCache != nil {
|
||||
if aCache.password == password && aCache.username == username && aCache.action == action {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
data := url.Values{}
|
||||
data.Add("username", username)
|
||||
data.Add("clientid", clientID)
|
||||
data.Add("password", password)
|
||||
|
||||
req, err := http.NewRequest("POST", config.AuthURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
log.Error("new request super: ", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
|
||||
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
log.Error("request super: ", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
io.Copy(ioutil.Discard, resp.Body)
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
addCache(action, clientID, username, password, "")
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// //CheckSuper check mqtt connect
|
||||
// func CheckSuper(clientID, username, password string) bool {
|
||||
// action := "connect"
|
||||
// {
|
||||
// aCache := checkCache(action, clientID, username, password, "")
|
||||
// if aCache != nil {
|
||||
// if aCache.password == password && aCache.username == username && aCache.action == action {
|
||||
// return true
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// data := url.Values{}
|
||||
// data.Add("username", username)
|
||||
// data.Add("clientid", clientID)
|
||||
// data.Add("password", password)
|
||||
|
||||
// req, err := http.NewRequest("POST", config.SuperURL, strings.NewReader(data.Encode()))
|
||||
// if err != nil {
|
||||
// log.Error("new request super: ", zap.Error(err))
|
||||
// return false
|
||||
// }
|
||||
// req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
// req.Header.Add("Content-Length", strconv.Itoa(len(data.Encode())))
|
||||
|
||||
// resp, err := httpClient.Do(req)
|
||||
// if err != nil {
|
||||
// log.Error("request super: ", zap.Error(err))
|
||||
// return false
|
||||
// }
|
||||
|
||||
// defer resp.Body.Close()
|
||||
// io.Copy(ioutil.Discard, resp.Body)
|
||||
|
||||
// if resp.StatusCode == http.StatusOK {
|
||||
// return true
|
||||
// }
|
||||
// return false
|
||||
// }
|
||||
|
||||
//CheckACL check mqtt connect
|
||||
func (a *authHTTP) CheckACL(username, access, topic string) bool {
|
||||
action := access
|
||||
{
|
||||
aCache := checkCache(action, "", username, "", topic)
|
||||
if aCache != nil {
|
||||
if aCache.topic == topic && aCache.action == action {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequest("GET", config.ACLURL, nil)
|
||||
if err != nil {
|
||||
log.Error("get acl: ", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
data := req.URL.Query()
|
||||
|
||||
data.Add("username", username)
|
||||
data.Add("topic", topic)
|
||||
data.Add("access", access)
|
||||
req.URL.RawQuery = data.Encode()
|
||||
// fmt.Println("req:", req)
|
||||
resp, err := a.client.Do(req)
|
||||
if err != nil {
|
||||
log.Error("request acl: ", zap.Error(err))
|
||||
return false
|
||||
}
|
||||
|
||||
defer resp.Body.Close()
|
||||
io.Copy(ioutil.Discard, resp.Body)
|
||||
|
||||
if resp.StatusCode == http.StatusOK {
|
||||
addCache(action, "", username, "", topic)
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
32
plugins/auth/authhttp/cache.go
Normal file
32
plugins/auth/authhttp/cache.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package authhttp
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/patrickmn/go-cache"
|
||||
)
|
||||
|
||||
type authCache struct {
|
||||
action string
|
||||
username string
|
||||
clientID string
|
||||
password string
|
||||
topic string
|
||||
}
|
||||
|
||||
var (
|
||||
// cache = make(map[string]authCache)
|
||||
c = cache.New(5*time.Minute, 10*time.Minute)
|
||||
)
|
||||
|
||||
func checkCache(action, clientID, username, password, topic string) *authCache {
|
||||
authc, found := c.Get(username)
|
||||
if found {
|
||||
return authc.(*authCache)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addCache(action, clientID, username, password, topic string) {
|
||||
c.Set(username, &authCache{action: action, username: username, clientID: clientID, password: password, topic: topic}, cache.DefaultExpiration)
|
||||
}
|
||||
5
plugins/auth/authhttp/http.json
Normal file
5
plugins/auth/authhttp/http.json
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"auth": "http://127.0.0.1:9090/mqtt/auth",
|
||||
"acl": "http://127.0.0.1:9090/mqtt/acl",
|
||||
"super": "http://127.0.0.1:9090/mqtt/superuser"
|
||||
}
|
||||
11
plugins/auth/mock.go
Normal file
11
plugins/auth/mock.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package auth
|
||||
|
||||
type mockAuth struct{}
|
||||
|
||||
func (m *mockAuth) CheckACL(action, username, topic string) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *mockAuth) CheckConnect(clientID, username, password string) bool {
|
||||
return true
|
||||
}
|
||||
49
plugins/bridge/bridge.go
Normal file
49
plugins/bridge/bridge.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package bridge
|
||||
|
||||
import "github.com/fhmq/hmq/logger"
|
||||
|
||||
const (
|
||||
//Connect mqtt connect
|
||||
Connect = "connect"
|
||||
//Publish mqtt publish
|
||||
Publish = "publish"
|
||||
//Subscribe mqtt sub
|
||||
Subscribe = "subscribe"
|
||||
//Unsubscribe mqtt sub
|
||||
Unsubscribe = "unsubscribe"
|
||||
//Disconnect mqtt disconenct
|
||||
Disconnect = "disconnect"
|
||||
)
|
||||
|
||||
var (
|
||||
log = logger.Get().Named("bridge")
|
||||
)
|
||||
|
||||
//Elements kafka publish elements
|
||||
type Elements struct {
|
||||
ClientID string `json:"clientid"`
|
||||
Username string `json:"username"`
|
||||
Topic string `json:"topic"`
|
||||
Payload string `json:"payload"`
|
||||
Timestamp int64 `json:"ts"`
|
||||
Size int32 `json:"size"`
|
||||
Action string `json:"action"`
|
||||
}
|
||||
|
||||
const (
|
||||
//Kafka plugin name
|
||||
Kafka = "kafka"
|
||||
)
|
||||
|
||||
type BridgeMQ interface {
|
||||
Publish(e *Elements) error
|
||||
}
|
||||
|
||||
func NewBridgeMQ(name string) BridgeMQ {
|
||||
switch name {
|
||||
case Kafka:
|
||||
return InitKafka()
|
||||
default:
|
||||
return &mockMQ{}
|
||||
}
|
||||
}
|
||||
120
plugins/bridge/kafka.go
Normal file
120
plugins/bridge/kafka.go
Normal file
@@ -0,0 +1,120 @@
|
||||
package bridge
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"regexp"
|
||||
|
||||
"github.com/Shopify/sarama"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
|
||||
type kafakConfig struct {
|
||||
Addr []string `json:"addr"`
|
||||
ConnectTopic string `json:"onConnect"`
|
||||
SubscribeTopic string `json:"onSubscribe"`
|
||||
PublishTopic string `json:"onPublish"`
|
||||
UnsubscribeTopic string `json:"onUnsubscribe"`
|
||||
DisconnectTopic string `json:"onDisconnect"`
|
||||
RegexpMap map[string]string `json:"regexpMap"`
|
||||
}
|
||||
|
||||
type kafka struct {
|
||||
kafakConfig kafakConfig
|
||||
kafkaClient sarama.AsyncProducer
|
||||
}
|
||||
|
||||
//Init init kafak client
|
||||
func InitKafka() *kafka {
|
||||
log.Info("start connect kafka....")
|
||||
content, err := ioutil.ReadFile("./plugins/mq/kafka/kafka.json")
|
||||
if err != nil {
|
||||
log.Fatal("Read config file error: ", zap.Error(err))
|
||||
}
|
||||
// log.Info(string(content))
|
||||
var config kafakConfig
|
||||
err = json.Unmarshal(content, &config)
|
||||
if err != nil {
|
||||
log.Fatal("Unmarshal config file error: ", zap.Error(err))
|
||||
}
|
||||
c := &kafka{kafakConfig: config}
|
||||
c.connect()
|
||||
return c
|
||||
}
|
||||
|
||||
//connect
|
||||
func (k *kafka) connect() {
|
||||
conf := sarama.NewConfig()
|
||||
conf.Version = sarama.V1_1_1_0
|
||||
kafkaClient, err := sarama.NewAsyncProducer(k.kafakConfig.Addr, conf)
|
||||
if err != nil {
|
||||
log.Fatal("create kafka async producer failed: ", zap.Error(err))
|
||||
}
|
||||
|
||||
go func() {
|
||||
for err := range kafkaClient.Errors() {
|
||||
log.Error("send msg to kafka failed: ", zap.Error(err))
|
||||
}
|
||||
}()
|
||||
|
||||
k.kafkaClient = kafkaClient
|
||||
}
|
||||
|
||||
//Publish publish to kafka
|
||||
func (k *kafka) Publish(e *Elements) error {
|
||||
config := k.kafakConfig
|
||||
key := e.ClientID
|
||||
var topics []string
|
||||
switch e.Action {
|
||||
case Connect:
|
||||
if config.ConnectTopic != "" {
|
||||
topics = append(topics, config.ConnectTopic)
|
||||
}
|
||||
case Publish:
|
||||
if config.PublishTopic != "" {
|
||||
topics = append(topics, config.PublishTopic)
|
||||
}
|
||||
// foreach regexp map config
|
||||
for reg, topic := range config.RegexpMap {
|
||||
match, _ := regexp.MatchString(reg, e.Topic)
|
||||
if match {
|
||||
topics = append(topics, topic)
|
||||
}
|
||||
}
|
||||
case Subscribe:
|
||||
if config.SubscribeTopic != "" {
|
||||
topics = append(topics, config.SubscribeTopic)
|
||||
}
|
||||
case Unsubscribe:
|
||||
if config.UnsubscribeTopic != "" {
|
||||
topics = append(topics, config.UnsubscribeTopic)
|
||||
}
|
||||
case Disconnect:
|
||||
if config.DisconnectTopic != "" {
|
||||
topics = append(topics, config.DisconnectTopic)
|
||||
}
|
||||
default:
|
||||
return errors.New("error action: " + e.Action)
|
||||
}
|
||||
|
||||
return k.publish(topics, key, e)
|
||||
|
||||
}
|
||||
|
||||
func (k *kafka) publish(topics []string, key string, msg *Elements) error {
|
||||
payload, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, topic := range topics {
|
||||
k.kafkaClient.Input() <- &sarama.ProducerMessage{
|
||||
Topic: topic,
|
||||
Key: sarama.ByteEncoder(key),
|
||||
Value: sarama.ByteEncoder(payload),
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
15
plugins/bridge/kafka/kafka.json
Normal file
15
plugins/bridge/kafka/kafka.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"addr": [
|
||||
"127.0.0.1:9090"
|
||||
],
|
||||
"onConnect": "onConnect",
|
||||
"onPublish": "onPublish",
|
||||
"onSubscribe": "onSubscribe",
|
||||
"onDisconnect": "onDisconnect",
|
||||
"onUnsubscribe": "onUnsubscribe",
|
||||
"regexpMap": [
|
||||
{
|
||||
"^/(.+)/(.+)/upload/(.*)$": "upload"
|
||||
}
|
||||
]
|
||||
}
|
||||
7
plugins/bridge/mock.go
Normal file
7
plugins/bridge/mock.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package bridge
|
||||
|
||||
type mockMQ struct{}
|
||||
|
||||
func (m *mockMQ) Publish(e *Elements) error {
|
||||
return nil
|
||||
}
|
||||
58
pool/fixpool.go
Normal file
58
pool/fixpool.go
Normal file
@@ -0,0 +1,58 @@
|
||||
package pool
|
||||
|
||||
import (
|
||||
"github.com/segmentio/fasthash/fnv1a"
|
||||
)
|
||||
|
||||
type WorkerPool struct {
|
||||
maxWorkers int
|
||||
taskQueue []chan func()
|
||||
stoppedChan chan struct{}
|
||||
}
|
||||
|
||||
func New(maxWorkers int) *WorkerPool {
|
||||
// There must be at least one worker.
|
||||
if maxWorkers < 1 {
|
||||
maxWorkers = 1
|
||||
}
|
||||
|
||||
// taskQueue is unbuffered since items are always removed immediately.
|
||||
pool := &WorkerPool{
|
||||
taskQueue: make([]chan func(), maxWorkers),
|
||||
maxWorkers: maxWorkers,
|
||||
stoppedChan: make(chan struct{}),
|
||||
}
|
||||
// Start the task dispatcher.
|
||||
pool.dispatch()
|
||||
|
||||
return pool
|
||||
}
|
||||
|
||||
func (p *WorkerPool) Submit(uid string, task func()) {
|
||||
idx := fnv1a.HashString64(uid) % uint64(p.maxWorkers)
|
||||
if task != nil {
|
||||
p.taskQueue[idx] <- task
|
||||
}
|
||||
}
|
||||
|
||||
func (p *WorkerPool) dispatch() {
|
||||
for i := 0; i < p.maxWorkers; i++ {
|
||||
p.taskQueue[i] = make(chan func(), 1024)
|
||||
go startWorker(p.taskQueue[i])
|
||||
}
|
||||
}
|
||||
|
||||
func startWorker(taskChan chan func()) {
|
||||
go func() {
|
||||
var task func()
|
||||
var ok bool
|
||||
for {
|
||||
task, ok = <-taskChan
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
// Execute the task.
|
||||
task()
|
||||
}
|
||||
}()
|
||||
}
|
||||
166
pool/pool.go
Normal file
166
pool/pool.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package pool
|
||||
|
||||
// import "time"
|
||||
|
||||
// const (
|
||||
// // This value is the size of the queue that workers register their
|
||||
// // availability to the dispatcher. There may be hundreds of workers, but
|
||||
// // only a small channel is needed to register some of the workers.
|
||||
// readyQueueSize = 64
|
||||
|
||||
// // If worker pool receives no new work for this period of time, then stop
|
||||
// // a worker goroutine.
|
||||
// idleTimeoutSec = 5
|
||||
// )
|
||||
|
||||
// type WorkerPool struct {
|
||||
// maxWorkers int
|
||||
// timeout time.Duration
|
||||
// taskQueue chan func()
|
||||
// readyWorkers chan chan func()
|
||||
// stoppedChan chan struct{}
|
||||
// }
|
||||
|
||||
// func New(maxWorkers int) *WorkerPool {
|
||||
// // There must be at least one worker.
|
||||
// if maxWorkers < 1 {
|
||||
// maxWorkers = 1
|
||||
// }
|
||||
|
||||
// // taskQueue is unbuffered since items are always removed immediately.
|
||||
// pool := &WorkerPool{
|
||||
// taskQueue: make(chan func()),
|
||||
// maxWorkers: maxWorkers,
|
||||
// readyWorkers: make(chan chan func(), readyQueueSize),
|
||||
// timeout: time.Second * idleTimeoutSec,
|
||||
// stoppedChan: make(chan struct{}),
|
||||
// }
|
||||
|
||||
// // Start the task dispatcher.
|
||||
// go pool.dispatch()
|
||||
|
||||
// return pool
|
||||
// }
|
||||
|
||||
// func (p *WorkerPool) Stop() {
|
||||
// if p.Stopped() {
|
||||
// return
|
||||
// }
|
||||
// close(p.taskQueue)
|
||||
// <-p.stoppedChan
|
||||
// }
|
||||
|
||||
// func (p *WorkerPool) Stopped() bool {
|
||||
// select {
|
||||
// case <-p.stoppedChan:
|
||||
// return true
|
||||
// default:
|
||||
// }
|
||||
// return false
|
||||
// }
|
||||
|
||||
// func (p *WorkerPool) Submit(task func()) {
|
||||
// if task != nil {
|
||||
// p.taskQueue <- task
|
||||
// }
|
||||
// }
|
||||
|
||||
// func (p *WorkerPool) SubmitWait(task func()) {
|
||||
// if task == nil {
|
||||
// return
|
||||
// }
|
||||
// doneChan := make(chan struct{})
|
||||
// p.taskQueue <- func() {
|
||||
// task()
|
||||
// close(doneChan)
|
||||
// }
|
||||
// <-doneChan
|
||||
// }
|
||||
|
||||
// func (p *WorkerPool) dispatch() {
|
||||
// defer close(p.stoppedChan)
|
||||
// timeout := time.NewTimer(p.timeout)
|
||||
// var workerCount int
|
||||
// var task func()
|
||||
// var ok bool
|
||||
// var workerTaskChan chan func()
|
||||
// startReady := make(chan chan func())
|
||||
// Loop:
|
||||
// for {
|
||||
// timeout.Reset(p.timeout)
|
||||
// select {
|
||||
// case task, ok = <-p.taskQueue:
|
||||
// if !ok {
|
||||
// break Loop
|
||||
// }
|
||||
// // Got a task to do.
|
||||
// select {
|
||||
// case workerTaskChan = <-p.readyWorkers:
|
||||
// // A worker is ready, so give task to worker.
|
||||
// workerTaskChan <- task
|
||||
// default:
|
||||
// // No workers ready.
|
||||
// // Create a new worker, if not at max.
|
||||
// if workerCount < p.maxWorkers {
|
||||
// workerCount++
|
||||
// go func(t func()) {
|
||||
// startWorker(startReady, p.readyWorkers)
|
||||
// // Submit the task when the new worker.
|
||||
// taskChan := <-startReady
|
||||
// taskChan <- t
|
||||
// }(task)
|
||||
// } else {
|
||||
// // Start a goroutine to submit the task when an existing
|
||||
// // worker is ready.
|
||||
// go func(t func()) {
|
||||
// taskChan := <-p.readyWorkers
|
||||
// taskChan <- t
|
||||
// }(task)
|
||||
// }
|
||||
// }
|
||||
// case <-timeout.C:
|
||||
// // Timed out waiting for work to arrive. Kill a ready worker.
|
||||
// if workerCount > 0 {
|
||||
// select {
|
||||
// case workerTaskChan = <-p.readyWorkers:
|
||||
// // A worker is ready, so kill.
|
||||
// close(workerTaskChan)
|
||||
// workerCount--
|
||||
// default:
|
||||
// // No work, but no ready workers. All workers are busy.
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Stop all remaining workers as they become ready.
|
||||
// for workerCount > 0 {
|
||||
// workerTaskChan = <-p.readyWorkers
|
||||
// close(workerTaskChan)
|
||||
// workerCount--
|
||||
// }
|
||||
// }
|
||||
|
||||
// func startWorker(startReady, readyWorkers chan chan func()) {
|
||||
// go func() {
|
||||
// taskChan := make(chan func())
|
||||
// var task func()
|
||||
// var ok bool
|
||||
// // Register availability on starReady channel.
|
||||
// startReady <- taskChan
|
||||
// for {
|
||||
// // Read task from dispatcher.
|
||||
// task, ok = <-taskChan
|
||||
// if !ok {
|
||||
// // Dispatcher has told worker to stop.
|
||||
// break
|
||||
// }
|
||||
|
||||
// // Execute the task.
|
||||
// task()
|
||||
|
||||
// // Register availability on readyWorkers channel.
|
||||
// readyWorkers <- taskChan
|
||||
// }
|
||||
// }()
|
||||
// }
|
||||
31
vendor/github.com/DataDog/zstd/.travis.yml
generated
vendored
Normal file
31
vendor/github.com/DataDog/zstd/.travis.yml
generated
vendored
Normal file
@@ -0,0 +1,31 @@
|
||||
dist: xenial
|
||||
language: go
|
||||
|
||||
go:
|
||||
- 1.10.x
|
||||
- 1.11.x
|
||||
- 1.12.x
|
||||
|
||||
os:
|
||||
- linux
|
||||
- osx
|
||||
|
||||
matrix:
|
||||
include:
|
||||
name: "Go 1.11.x CentOS 32bits"
|
||||
language: go
|
||||
go: 1.11.x
|
||||
os: linux
|
||||
services:
|
||||
- docker
|
||||
script:
|
||||
# Please update Go version in travis_test_32 as needed
|
||||
- "docker run -i -v \"${PWD}:/zstd\" toopher/centos-i386:centos6 /bin/bash -c \"linux32 --32bit i386 /zstd/travis_test_32.sh\""
|
||||
|
||||
install:
|
||||
- "wget https://github.com/DataDog/zstd/files/2246767/mr.zip"
|
||||
- "unzip mr.zip"
|
||||
script:
|
||||
- "go build"
|
||||
- "PAYLOAD=`pwd`/mr go test -v"
|
||||
- "PAYLOAD=`pwd`/mr go test -bench ."
|
||||
27
vendor/github.com/DataDog/zstd/LICENSE
generated
vendored
Normal file
27
vendor/github.com/DataDog/zstd/LICENSE
generated
vendored
Normal file
@@ -0,0 +1,27 @@
|
||||
Simplified BSD License
|
||||
|
||||
Copyright (c) 2016, Datadog <info@datadoghq.com>
|
||||
All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice,
|
||||
this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
* Neither the name of the copyright holder nor the names of its contributors
|
||||
may be used to endorse or promote products derived from this software
|
||||
without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
120
vendor/github.com/DataDog/zstd/README.md
generated
vendored
Normal file
120
vendor/github.com/DataDog/zstd/README.md
generated
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
# Zstd Go Wrapper
|
||||
|
||||
[C Zstd Homepage](https://github.com/Cyan4973/zstd)
|
||||
|
||||
The current headers and C files are from *v1.3.8* (Commit
|
||||
[470344d](https://github.com/facebook/zstd/releases/tag/v1.3.8)).
|
||||
|
||||
## Usage
|
||||
|
||||
There are two main APIs:
|
||||
|
||||
* simple Compress/Decompress
|
||||
* streaming API (io.Reader/io.Writer)
|
||||
|
||||
The compress/decompress APIs mirror that of lz4, while the streaming API was
|
||||
designed to be a drop-in replacement for zlib.
|
||||
|
||||
### Simple `Compress/Decompress`
|
||||
|
||||
|
||||
```go
|
||||
// Compress compresses the byte array given in src and writes it to dst.
|
||||
// If you already have a buffer allocated, you can pass it to prevent allocation
|
||||
// If not, you can pass nil as dst.
|
||||
// If the buffer is too small, it will be reallocated, resized, and returned bu the function
|
||||
// If dst is nil, this will allocate the worst case size (CompressBound(src))
|
||||
Compress(dst, src []byte) ([]byte, error)
|
||||
```
|
||||
|
||||
```go
|
||||
// CompressLevel is the same as Compress but you can pass another compression level
|
||||
CompressLevel(dst, src []byte, level int) ([]byte, error)
|
||||
```
|
||||
|
||||
```go
|
||||
// Decompress will decompress your payload into dst.
|
||||
// If you already have a buffer allocated, you can pass it to prevent allocation
|
||||
// If not, you can pass nil as dst (allocates a 4*src size as default).
|
||||
// If the buffer is too small, it will retry 3 times by doubling the dst size
|
||||
// After max retries, it will switch to the slower stream API to be sure to be able
|
||||
// to decompress. Currently switches if compression ratio > 4*2**3=32.
|
||||
Decompress(dst, src []byte) ([]byte, error)
|
||||
```
|
||||
|
||||
### Stream API
|
||||
|
||||
```go
|
||||
// NewWriter creates a new object that can optionally be initialized with
|
||||
// a precomputed dictionary. If dict is nil, compress without a dictionary.
|
||||
// The dictionary array should not be changed during the use of this object.
|
||||
// You MUST CALL Close() to write the last bytes of a zstd stream and free C objects.
|
||||
NewWriter(w io.Writer) *Writer
|
||||
NewWriterLevel(w io.Writer, level int) *Writer
|
||||
NewWriterLevelDict(w io.Writer, level int, dict []byte) *Writer
|
||||
|
||||
// Write compresses the input data and write it to the underlying writer
|
||||
(w *Writer) Write(p []byte) (int, error)
|
||||
|
||||
// Close flushes the buffer and frees C zstd objects
|
||||
(w *Writer) Close() error
|
||||
```
|
||||
|
||||
```go
|
||||
// NewReader returns a new io.ReadCloser that will decompress data from the
|
||||
// underlying reader. If a dictionary is provided to NewReaderDict, it must
|
||||
// not be modified until Close is called. It is the caller's responsibility
|
||||
// to call Close, which frees up C objects.
|
||||
NewReader(r io.Reader) io.ReadCloser
|
||||
NewReaderDict(r io.Reader, dict []byte) io.ReadCloser
|
||||
```
|
||||
|
||||
### Benchmarks (benchmarked with v0.5.0)
|
||||
|
||||
The author of Zstd also wrote lz4. Zstd is intended to occupy a speed/ratio
|
||||
level similar to what zlib currently provides. In our tests, the can always
|
||||
be made to be better than zlib by chosing an appropriate level while still
|
||||
keeping compression and decompression time faster than zlib.
|
||||
|
||||
You can run the benchmarks against your own payloads by using the Go benchmarks tool.
|
||||
Just export your payload filepath as the `PAYLOAD` environment variable and run the benchmarks:
|
||||
|
||||
```go
|
||||
go test -bench .
|
||||
```
|
||||
|
||||
Compression of a 7Mb pdf zstd (this wrapper) vs [czlib](https://github.com/DataDog/czlib):
|
||||
```
|
||||
BenchmarkCompression 5 221056624 ns/op 67.34 MB/s
|
||||
BenchmarkDecompression 100 18370416 ns/op 810.32 MB/s
|
||||
|
||||
BenchmarkFzlibCompress 2 610156603 ns/op 24.40 MB/s
|
||||
BenchmarkFzlibDecompress 20 81195246 ns/op 183.33 MB/s
|
||||
```
|
||||
|
||||
Ratio is also better by a margin of ~20%.
|
||||
Compression speed is always better than zlib on all the payloads we tested;
|
||||
However, [czlib](https://github.com/DataDog/czlib) has optimisations that make it
|
||||
faster at decompressiong small payloads:
|
||||
|
||||
```
|
||||
Testing with size: 11... czlib: 8.97 MB/s, zstd: 3.26 MB/s
|
||||
Testing with size: 27... czlib: 23.3 MB/s, zstd: 8.22 MB/s
|
||||
Testing with size: 62... czlib: 31.6 MB/s, zstd: 19.49 MB/s
|
||||
Testing with size: 141... czlib: 74.54 MB/s, zstd: 42.55 MB/s
|
||||
Testing with size: 323... czlib: 155.14 MB/s, zstd: 99.39 MB/s
|
||||
Testing with size: 739... czlib: 235.9 MB/s, zstd: 216.45 MB/s
|
||||
Testing with size: 1689... czlib: 116.45 MB/s, zstd: 345.64 MB/s
|
||||
Testing with size: 3858... czlib: 176.39 MB/s, zstd: 617.56 MB/s
|
||||
Testing with size: 8811... czlib: 254.11 MB/s, zstd: 824.34 MB/s
|
||||
Testing with size: 20121... czlib: 197.43 MB/s, zstd: 1339.11 MB/s
|
||||
Testing with size: 45951... czlib: 201.62 MB/s, zstd: 1951.57 MB/s
|
||||
```
|
||||
|
||||
zstd starts to shine with payloads > 1KB
|
||||
|
||||
### Stability - Current state: STABLE
|
||||
|
||||
The C library seems to be pretty stable and according to the author has been tested and fuzzed.
|
||||
|
||||
For the Go wrapper, the test cover most usual cases and we have succesfully tested it on all staging and prod data.
|
||||
30
vendor/github.com/DataDog/zstd/ZSTD_LICENSE
generated
vendored
Normal file
30
vendor/github.com/DataDog/zstd/ZSTD_LICENSE
generated
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
BSD License
|
||||
|
||||
For Zstandard software
|
||||
|
||||
Copyright (c) 2016-present, Facebook, Inc. All rights reserved.
|
||||
|
||||
Redistribution and use in source and binary forms, with or without modification,
|
||||
are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name Facebook nor the names of its contributors may be used to
|
||||
endorse or promote products derived from this software without specific
|
||||
prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
||||
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
||||
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
||||
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
||||
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
||||
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
||||
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
||||
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
455
vendor/github.com/DataDog/zstd/bitstream.h
generated
vendored
Normal file
455
vendor/github.com/DataDog/zstd/bitstream.h
generated
vendored
Normal file
@@ -0,0 +1,455 @@
|
||||
/* ******************************************************************
|
||||
bitstream
|
||||
Part of FSE library
|
||||
Copyright (C) 2013-present, Yann Collet.
|
||||
|
||||
BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
You can contact the author at :
|
||||
- Source repository : https://github.com/Cyan4973/FiniteStateEntropy
|
||||
****************************************************************** */
|
||||
#ifndef BITSTREAM_H_MODULE
|
||||
#define BITSTREAM_H_MODULE
|
||||
|
||||
#if defined (__cplusplus)
|
||||
extern "C" {
|
||||
#endif
|
||||
|
||||
/*
|
||||
* This API consists of small unitary functions, which must be inlined for best performance.
|
||||
* Since link-time-optimization is not available for all compilers,
|
||||
* these functions are defined into a .h to be included.
|
||||
*/
|
||||
|
||||
/*-****************************************
|
||||
* Dependencies
|
||||
******************************************/
|
||||
#include "mem.h" /* unaligned access routines */
|
||||
#include "debug.h" /* assert(), DEBUGLOG(), RAWLOG() */
|
||||
#include "error_private.h" /* error codes and messages */
|
||||
|
||||
|
||||
/*=========================================
|
||||
* Target specific
|
||||
=========================================*/
|
||||
#if defined(__BMI__) && defined(__GNUC__)
|
||||
# include <immintrin.h> /* support for bextr (experimental) */
|
||||
#endif
|
||||
|
||||
#define STREAM_ACCUMULATOR_MIN_32 25
|
||||
#define STREAM_ACCUMULATOR_MIN_64 57
|
||||
#define STREAM_ACCUMULATOR_MIN ((U32)(MEM_32bits() ? STREAM_ACCUMULATOR_MIN_32 : STREAM_ACCUMULATOR_MIN_64))
|
||||
|
||||
|
||||
/*-******************************************
|
||||
* bitStream encoding API (write forward)
|
||||
********************************************/
|
||||
/* bitStream can mix input from multiple sources.
|
||||
* A critical property of these streams is that they encode and decode in **reverse** direction.
|
||||
* So the first bit sequence you add will be the last to be read, like a LIFO stack.
|
||||
*/
|
||||
typedef struct {
|
||||
size_t bitContainer;
|
||||
unsigned bitPos;
|
||||
char* startPtr;
|
||||
char* ptr;
|
||||
char* endPtr;
|
||||
} BIT_CStream_t;
|
||||
|
||||
MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC, void* dstBuffer, size_t dstCapacity);
|
||||
MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC, size_t value, unsigned nbBits);
|
||||
MEM_STATIC void BIT_flushBits(BIT_CStream_t* bitC);
|
||||
MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC);
|
||||
|
||||
/* Start with initCStream, providing the size of buffer to write into.
|
||||
* bitStream will never write outside of this buffer.
|
||||
* `dstCapacity` must be >= sizeof(bitD->bitContainer), otherwise @return will be an error code.
|
||||
*
|
||||
* bits are first added to a local register.
|
||||
* Local register is size_t, hence 64-bits on 64-bits systems, or 32-bits on 32-bits systems.
|
||||
* Writing data into memory is an explicit operation, performed by the flushBits function.
|
||||
* Hence keep track how many bits are potentially stored into local register to avoid register overflow.
|
||||
* After a flushBits, a maximum of 7 bits might still be stored into local register.
|
||||
*
|
||||
* Avoid storing elements of more than 24 bits if you want compatibility with 32-bits bitstream readers.
|
||||
*
|
||||
* Last operation is to close the bitStream.
|
||||
* The function returns the final size of CStream in bytes.
|
||||
* If data couldn't fit into `dstBuffer`, it will return a 0 ( == not storable)
|
||||
*/
|
||||
|
||||
|
||||
/*-********************************************
|
||||
* bitStream decoding API (read backward)
|
||||
**********************************************/
|
||||
typedef struct {
|
||||
size_t bitContainer;
|
||||
unsigned bitsConsumed;
|
||||
const char* ptr;
|
||||
const char* start;
|
||||
const char* limitPtr;
|
||||
} BIT_DStream_t;
|
||||
|
||||
typedef enum { BIT_DStream_unfinished = 0,
|
||||
BIT_DStream_endOfBuffer = 1,
|
||||
BIT_DStream_completed = 2,
|
||||
BIT_DStream_overflow = 3 } BIT_DStream_status; /* result of BIT_reloadDStream() */
|
||||
/* 1,2,4,8 would be better for bitmap combinations, but slows down performance a bit ... :( */
|
||||
|
||||
MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize);
|
||||
MEM_STATIC size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits);
|
||||
MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD);
|
||||
MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* bitD);
|
||||
|
||||
|
||||
/* Start by invoking BIT_initDStream().
|
||||
* A chunk of the bitStream is then stored into a local register.
|
||||
* Local register size is 64-bits on 64-bits systems, 32-bits on 32-bits systems (size_t).
|
||||
* You can then retrieve bitFields stored into the local register, **in reverse order**.
|
||||
* Local register is explicitly reloaded from memory by the BIT_reloadDStream() method.
|
||||
* A reload guarantee a minimum of ((8*sizeof(bitD->bitContainer))-7) bits when its result is BIT_DStream_unfinished.
|
||||
* Otherwise, it can be less than that, so proceed accordingly.
|
||||
* Checking if DStream has reached its end can be performed with BIT_endOfDStream().
|
||||
*/
|
||||
|
||||
|
||||
/*-****************************************
|
||||
* unsafe API
|
||||
******************************************/
|
||||
MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC, size_t value, unsigned nbBits);
|
||||
/* faster, but works only if value is "clean", meaning all high bits above nbBits are 0 */
|
||||
|
||||
MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC);
|
||||
/* unsafe version; does not check buffer overflow */
|
||||
|
||||
MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits);
|
||||
/* faster, but works only if nbBits >= 1 */
|
||||
|
||||
|
||||
|
||||
/*-**************************************************************
|
||||
* Internal functions
|
||||
****************************************************************/
|
||||
MEM_STATIC unsigned BIT_highbit32 (U32 val)
|
||||
{
|
||||
assert(val != 0);
|
||||
{
|
||||
# if defined(_MSC_VER) /* Visual */
|
||||
unsigned long r=0;
|
||||
_BitScanReverse ( &r, val );
|
||||
return (unsigned) r;
|
||||
# elif defined(__GNUC__) && (__GNUC__ >= 3) /* Use GCC Intrinsic */
|
||||
return 31 - __builtin_clz (val);
|
||||
# else /* Software version */
|
||||
static const unsigned DeBruijnClz[32] = { 0, 9, 1, 10, 13, 21, 2, 29,
|
||||
11, 14, 16, 18, 22, 25, 3, 30,
|
||||
8, 12, 20, 28, 15, 17, 24, 7,
|
||||
19, 27, 23, 6, 26, 5, 4, 31 };
|
||||
U32 v = val;
|
||||
v |= v >> 1;
|
||||
v |= v >> 2;
|
||||
v |= v >> 4;
|
||||
v |= v >> 8;
|
||||
v |= v >> 16;
|
||||
return DeBruijnClz[ (U32) (v * 0x07C4ACDDU) >> 27];
|
||||
# endif
|
||||
}
|
||||
}
|
||||
|
||||
/*===== Local Constants =====*/
|
||||
static const unsigned BIT_mask[] = {
|
||||
0, 1, 3, 7, 0xF, 0x1F,
|
||||
0x3F, 0x7F, 0xFF, 0x1FF, 0x3FF, 0x7FF,
|
||||
0xFFF, 0x1FFF, 0x3FFF, 0x7FFF, 0xFFFF, 0x1FFFF,
|
||||
0x3FFFF, 0x7FFFF, 0xFFFFF, 0x1FFFFF, 0x3FFFFF, 0x7FFFFF,
|
||||
0xFFFFFF, 0x1FFFFFF, 0x3FFFFFF, 0x7FFFFFF, 0xFFFFFFF, 0x1FFFFFFF,
|
||||
0x3FFFFFFF, 0x7FFFFFFF}; /* up to 31 bits */
|
||||
#define BIT_MASK_SIZE (sizeof(BIT_mask) / sizeof(BIT_mask[0]))
|
||||
|
||||
/*-**************************************************************
|
||||
* bitStream encoding
|
||||
****************************************************************/
|
||||
/*! BIT_initCStream() :
|
||||
* `dstCapacity` must be > sizeof(size_t)
|
||||
* @return : 0 if success,
|
||||
* otherwise an error code (can be tested using ERR_isError()) */
|
||||
MEM_STATIC size_t BIT_initCStream(BIT_CStream_t* bitC,
|
||||
void* startPtr, size_t dstCapacity)
|
||||
{
|
||||
bitC->bitContainer = 0;
|
||||
bitC->bitPos = 0;
|
||||
bitC->startPtr = (char*)startPtr;
|
||||
bitC->ptr = bitC->startPtr;
|
||||
bitC->endPtr = bitC->startPtr + dstCapacity - sizeof(bitC->bitContainer);
|
||||
if (dstCapacity <= sizeof(bitC->bitContainer)) return ERROR(dstSize_tooSmall);
|
||||
return 0;
|
||||
}
|
||||
|
||||
/*! BIT_addBits() :
|
||||
* can add up to 31 bits into `bitC`.
|
||||
* Note : does not check for register overflow ! */
|
||||
MEM_STATIC void BIT_addBits(BIT_CStream_t* bitC,
|
||||
size_t value, unsigned nbBits)
|
||||
{
|
||||
MEM_STATIC_ASSERT(BIT_MASK_SIZE == 32);
|
||||
assert(nbBits < BIT_MASK_SIZE);
|
||||
assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8);
|
||||
bitC->bitContainer |= (value & BIT_mask[nbBits]) << bitC->bitPos;
|
||||
bitC->bitPos += nbBits;
|
||||
}
|
||||
|
||||
/*! BIT_addBitsFast() :
|
||||
* works only if `value` is _clean_,
|
||||
* meaning all high bits above nbBits are 0 */
|
||||
MEM_STATIC void BIT_addBitsFast(BIT_CStream_t* bitC,
|
||||
size_t value, unsigned nbBits)
|
||||
{
|
||||
assert((value>>nbBits) == 0);
|
||||
assert(nbBits + bitC->bitPos < sizeof(bitC->bitContainer) * 8);
|
||||
bitC->bitContainer |= value << bitC->bitPos;
|
||||
bitC->bitPos += nbBits;
|
||||
}
|
||||
|
||||
/*! BIT_flushBitsFast() :
|
||||
* assumption : bitContainer has not overflowed
|
||||
* unsafe version; does not check buffer overflow */
|
||||
MEM_STATIC void BIT_flushBitsFast(BIT_CStream_t* bitC)
|
||||
{
|
||||
size_t const nbBytes = bitC->bitPos >> 3;
|
||||
assert(bitC->bitPos < sizeof(bitC->bitContainer) * 8);
|
||||
MEM_writeLEST(bitC->ptr, bitC->bitContainer);
|
||||
bitC->ptr += nbBytes;
|
||||
assert(bitC->ptr <= bitC->endPtr);
|
||||
bitC->bitPos &= 7;
|
||||
bitC->bitContainer >>= nbBytes*8;
|
||||
}
|
||||
|
||||
/*! BIT_flushBits() :
|
||||
* assumption : bitContainer has not overflowed
|
||||
* safe version; check for buffer overflow, and prevents it.
|
||||
* note : does not signal buffer overflow.
|
||||
* overflow will be revealed later on using BIT_closeCStream() */
|
||||
MEM_STATIC void BIT_flushBits(BIT_CStream_t* bitC)
|
||||
{
|
||||
size_t const nbBytes = bitC->bitPos >> 3;
|
||||
assert(bitC->bitPos < sizeof(bitC->bitContainer) * 8);
|
||||
MEM_writeLEST(bitC->ptr, bitC->bitContainer);
|
||||
bitC->ptr += nbBytes;
|
||||
if (bitC->ptr > bitC->endPtr) bitC->ptr = bitC->endPtr;
|
||||
bitC->bitPos &= 7;
|
||||
bitC->bitContainer >>= nbBytes*8;
|
||||
}
|
||||
|
||||
/*! BIT_closeCStream() :
|
||||
* @return : size of CStream, in bytes,
|
||||
* or 0 if it could not fit into dstBuffer */
|
||||
MEM_STATIC size_t BIT_closeCStream(BIT_CStream_t* bitC)
|
||||
{
|
||||
BIT_addBitsFast(bitC, 1, 1); /* endMark */
|
||||
BIT_flushBits(bitC);
|
||||
if (bitC->ptr >= bitC->endPtr) return 0; /* overflow detected */
|
||||
return (bitC->ptr - bitC->startPtr) + (bitC->bitPos > 0);
|
||||
}
|
||||
|
||||
|
||||
/*-********************************************************
|
||||
* bitStream decoding
|
||||
**********************************************************/
|
||||
/*! BIT_initDStream() :
|
||||
* Initialize a BIT_DStream_t.
|
||||
* `bitD` : a pointer to an already allocated BIT_DStream_t structure.
|
||||
* `srcSize` must be the *exact* size of the bitStream, in bytes.
|
||||
* @return : size of stream (== srcSize), or an errorCode if a problem is detected
|
||||
*/
|
||||
MEM_STATIC size_t BIT_initDStream(BIT_DStream_t* bitD, const void* srcBuffer, size_t srcSize)
|
||||
{
|
||||
if (srcSize < 1) { memset(bitD, 0, sizeof(*bitD)); return ERROR(srcSize_wrong); }
|
||||
|
||||
bitD->start = (const char*)srcBuffer;
|
||||
bitD->limitPtr = bitD->start + sizeof(bitD->bitContainer);
|
||||
|
||||
if (srcSize >= sizeof(bitD->bitContainer)) { /* normal case */
|
||||
bitD->ptr = (const char*)srcBuffer + srcSize - sizeof(bitD->bitContainer);
|
||||
bitD->bitContainer = MEM_readLEST(bitD->ptr);
|
||||
{ BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1];
|
||||
bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0; /* ensures bitsConsumed is always set */
|
||||
if (lastByte == 0) return ERROR(GENERIC); /* endMark not present */ }
|
||||
} else {
|
||||
bitD->ptr = bitD->start;
|
||||
bitD->bitContainer = *(const BYTE*)(bitD->start);
|
||||
switch(srcSize)
|
||||
{
|
||||
case 7: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[6]) << (sizeof(bitD->bitContainer)*8 - 16);
|
||||
/* fall-through */
|
||||
|
||||
case 6: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[5]) << (sizeof(bitD->bitContainer)*8 - 24);
|
||||
/* fall-through */
|
||||
|
||||
case 5: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[4]) << (sizeof(bitD->bitContainer)*8 - 32);
|
||||
/* fall-through */
|
||||
|
||||
case 4: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[3]) << 24;
|
||||
/* fall-through */
|
||||
|
||||
case 3: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[2]) << 16;
|
||||
/* fall-through */
|
||||
|
||||
case 2: bitD->bitContainer += (size_t)(((const BYTE*)(srcBuffer))[1]) << 8;
|
||||
/* fall-through */
|
||||
|
||||
default: break;
|
||||
}
|
||||
{ BYTE const lastByte = ((const BYTE*)srcBuffer)[srcSize-1];
|
||||
bitD->bitsConsumed = lastByte ? 8 - BIT_highbit32(lastByte) : 0;
|
||||
if (lastByte == 0) return ERROR(corruption_detected); /* endMark not present */
|
||||
}
|
||||
bitD->bitsConsumed += (U32)(sizeof(bitD->bitContainer) - srcSize)*8;
|
||||
}
|
||||
|
||||
return srcSize;
|
||||
}
|
||||
|
||||
MEM_STATIC size_t BIT_getUpperBits(size_t bitContainer, U32 const start)
|
||||
{
|
||||
return bitContainer >> start;
|
||||
}
|
||||
|
||||
MEM_STATIC size_t BIT_getMiddleBits(size_t bitContainer, U32 const start, U32 const nbBits)
|
||||
{
|
||||
U32 const regMask = sizeof(bitContainer)*8 - 1;
|
||||
/* if start > regMask, bitstream is corrupted, and result is undefined */
|
||||
assert(nbBits < BIT_MASK_SIZE);
|
||||
return (bitContainer >> (start & regMask)) & BIT_mask[nbBits];
|
||||
}
|
||||
|
||||
MEM_STATIC size_t BIT_getLowerBits(size_t bitContainer, U32 const nbBits)
|
||||
{
|
||||
assert(nbBits < BIT_MASK_SIZE);
|
||||
return bitContainer & BIT_mask[nbBits];
|
||||
}
|
||||
|
||||
/*! BIT_lookBits() :
|
||||
* Provides next n bits from local register.
|
||||
* local register is not modified.
|
||||
* On 32-bits, maxNbBits==24.
|
||||
* On 64-bits, maxNbBits==56.
|
||||
* @return : value extracted */
|
||||
MEM_STATIC size_t BIT_lookBits(const BIT_DStream_t* bitD, U32 nbBits)
|
||||
{
|
||||
/* arbitrate between double-shift and shift+mask */
|
||||
#if 1
|
||||
/* if bitD->bitsConsumed + nbBits > sizeof(bitD->bitContainer)*8,
|
||||
* bitstream is likely corrupted, and result is undefined */
|
||||
return BIT_getMiddleBits(bitD->bitContainer, (sizeof(bitD->bitContainer)*8) - bitD->bitsConsumed - nbBits, nbBits);
|
||||
#else
|
||||
/* this code path is slower on my os-x laptop */
|
||||
U32 const regMask = sizeof(bitD->bitContainer)*8 - 1;
|
||||
return ((bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> 1) >> ((regMask-nbBits) & regMask);
|
||||
#endif
|
||||
}
|
||||
|
||||
/*! BIT_lookBitsFast() :
|
||||
* unsafe version; only works if nbBits >= 1 */
|
||||
MEM_STATIC size_t BIT_lookBitsFast(const BIT_DStream_t* bitD, U32 nbBits)
|
||||
{
|
||||
U32 const regMask = sizeof(bitD->bitContainer)*8 - 1;
|
||||
assert(nbBits >= 1);
|
||||
return (bitD->bitContainer << (bitD->bitsConsumed & regMask)) >> (((regMask+1)-nbBits) & regMask);
|
||||
}
|
||||
|
||||
MEM_STATIC void BIT_skipBits(BIT_DStream_t* bitD, U32 nbBits)
|
||||
{
|
||||
bitD->bitsConsumed += nbBits;
|
||||
}
|
||||
|
||||
/*! BIT_readBits() :
|
||||
* Read (consume) next n bits from local register and update.
|
||||
* Pay attention to not read more than nbBits contained into local register.
|
||||
* @return : extracted value. */
|
||||
MEM_STATIC size_t BIT_readBits(BIT_DStream_t* bitD, unsigned nbBits)
|
||||
{
|
||||
size_t const value = BIT_lookBits(bitD, nbBits);
|
||||
BIT_skipBits(bitD, nbBits);
|
||||
return value;
|
||||
}
|
||||
|
||||
/*! BIT_readBitsFast() :
|
||||
* unsafe version; only works only if nbBits >= 1 */
|
||||
MEM_STATIC size_t BIT_readBitsFast(BIT_DStream_t* bitD, unsigned nbBits)
|
||||
{
|
||||
size_t const value = BIT_lookBitsFast(bitD, nbBits);
|
||||
assert(nbBits >= 1);
|
||||
BIT_skipBits(bitD, nbBits);
|
||||
return value;
|
||||
}
|
||||
|
||||
/*! BIT_reloadDStream() :
|
||||
* Refill `bitD` from buffer previously set in BIT_initDStream() .
|
||||
* This function is safe, it guarantees it will not read beyond src buffer.
|
||||
* @return : status of `BIT_DStream_t` internal register.
|
||||
* when status == BIT_DStream_unfinished, internal register is filled with at least 25 or 57 bits */
|
||||
MEM_STATIC BIT_DStream_status BIT_reloadDStream(BIT_DStream_t* bitD)
|
||||
{
|
||||
if (bitD->bitsConsumed > (sizeof(bitD->bitContainer)*8)) /* overflow detected, like end of stream */
|
||||
return BIT_DStream_overflow;
|
||||
|
||||
if (bitD->ptr >= bitD->limitPtr) {
|
||||
bitD->ptr -= bitD->bitsConsumed >> 3;
|
||||
bitD->bitsConsumed &= 7;
|
||||
bitD->bitContainer = MEM_readLEST(bitD->ptr);
|
||||
return BIT_DStream_unfinished;
|
||||
}
|
||||
if (bitD->ptr == bitD->start) {
|
||||
if (bitD->bitsConsumed < sizeof(bitD->bitContainer)*8) return BIT_DStream_endOfBuffer;
|
||||
return BIT_DStream_completed;
|
||||
}
|
||||
/* start < ptr < limitPtr */
|
||||
{ U32 nbBytes = bitD->bitsConsumed >> 3;
|
||||
BIT_DStream_status result = BIT_DStream_unfinished;
|
||||
if (bitD->ptr - nbBytes < bitD->start) {
|
||||
nbBytes = (U32)(bitD->ptr - bitD->start); /* ptr > start */
|
||||
result = BIT_DStream_endOfBuffer;
|
||||
}
|
||||
bitD->ptr -= nbBytes;
|
||||
bitD->bitsConsumed -= nbBytes*8;
|
||||
bitD->bitContainer = MEM_readLEST(bitD->ptr); /* reminder : srcSize > sizeof(bitD->bitContainer), otherwise bitD->ptr == bitD->start */
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
/*! BIT_endOfDStream() :
|
||||
* @return : 1 if DStream has _exactly_ reached its end (all bits consumed).
|
||||
*/
|
||||
MEM_STATIC unsigned BIT_endOfDStream(const BIT_DStream_t* DStream)
|
||||
{
|
||||
return ((DStream->ptr == DStream->start) && (DStream->bitsConsumed == sizeof(DStream->bitContainer)*8));
|
||||
}
|
||||
|
||||
#if defined (__cplusplus)
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif /* BITSTREAM_H_MODULE */
|
||||
140
vendor/github.com/DataDog/zstd/compiler.h
generated
vendored
Normal file
140
vendor/github.com/DataDog/zstd/compiler.h
generated
vendored
Normal file
@@ -0,0 +1,140 @@
|
||||
/*
|
||||
* Copyright (c) 2016-present, Yann Collet, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under both the BSD-style license (found in the
|
||||
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
|
||||
* in the COPYING file in the root directory of this source tree).
|
||||
* You may select, at your option, one of the above-listed licenses.
|
||||
*/
|
||||
|
||||
#ifndef ZSTD_COMPILER_H
|
||||
#define ZSTD_COMPILER_H
|
||||
|
||||
/*-*******************************************************
|
||||
* Compiler specifics
|
||||
*********************************************************/
|
||||
/* force inlining */
|
||||
|
||||
#if !defined(ZSTD_NO_INLINE)
|
||||
#if defined (__GNUC__) || defined(__cplusplus) || defined(__STDC_VERSION__) && __STDC_VERSION__ >= 199901L /* C99 */
|
||||
# define INLINE_KEYWORD inline
|
||||
#else
|
||||
# define INLINE_KEYWORD
|
||||
#endif
|
||||
|
||||
#if defined(__GNUC__)
|
||||
# define FORCE_INLINE_ATTR __attribute__((always_inline))
|
||||
#elif defined(_MSC_VER)
|
||||
# define FORCE_INLINE_ATTR __forceinline
|
||||
#else
|
||||
# define FORCE_INLINE_ATTR
|
||||
#endif
|
||||
|
||||
#else
|
||||
|
||||
#define INLINE_KEYWORD
|
||||
#define FORCE_INLINE_ATTR
|
||||
|
||||
#endif
|
||||
|
||||
/**
|
||||
* FORCE_INLINE_TEMPLATE is used to define C "templates", which take constant
|
||||
* parameters. They must be inlined for the compiler to elimininate the constant
|
||||
* branches.
|
||||
*/
|
||||
#define FORCE_INLINE_TEMPLATE static INLINE_KEYWORD FORCE_INLINE_ATTR
|
||||
/**
|
||||
* HINT_INLINE is used to help the compiler generate better code. It is *not*
|
||||
* used for "templates", so it can be tweaked based on the compilers
|
||||
* performance.
|
||||
*
|
||||
* gcc-4.8 and gcc-4.9 have been shown to benefit from leaving off the
|
||||
* always_inline attribute.
|
||||
*
|
||||
* clang up to 5.0.0 (trunk) benefit tremendously from the always_inline
|
||||
* attribute.
|
||||
*/
|
||||
#if !defined(__clang__) && defined(__GNUC__) && __GNUC__ >= 4 && __GNUC_MINOR__ >= 8 && __GNUC__ < 5
|
||||
# define HINT_INLINE static INLINE_KEYWORD
|
||||
#else
|
||||
# define HINT_INLINE static INLINE_KEYWORD FORCE_INLINE_ATTR
|
||||
#endif
|
||||
|
||||
/* force no inlining */
|
||||
#ifdef _MSC_VER
|
||||
# define FORCE_NOINLINE static __declspec(noinline)
|
||||
#else
|
||||
# ifdef __GNUC__
|
||||
# define FORCE_NOINLINE static __attribute__((__noinline__))
|
||||
# else
|
||||
# define FORCE_NOINLINE static
|
||||
# endif
|
||||
#endif
|
||||
|
||||
/* target attribute */
|
||||
#ifndef __has_attribute
|
||||
#define __has_attribute(x) 0 /* Compatibility with non-clang compilers. */
|
||||
#endif
|
||||
#if defined(__GNUC__)
|
||||
# define TARGET_ATTRIBUTE(target) __attribute__((__target__(target)))
|
||||
#else
|
||||
# define TARGET_ATTRIBUTE(target)
|
||||
#endif
|
||||
|
||||
/* Enable runtime BMI2 dispatch based on the CPU.
|
||||
* Enabled for clang & gcc >=4.8 on x86 when BMI2 isn't enabled by default.
|
||||
*/
|
||||
#ifndef DYNAMIC_BMI2
|
||||
#if ((defined(__clang__) && __has_attribute(__target__)) \
|
||||
|| (defined(__GNUC__) \
|
||||
&& (__GNUC__ >= 5 || (__GNUC__ == 4 && __GNUC_MINOR__ >= 8)))) \
|
||||
&& (defined(__x86_64__) || defined(_M_X86)) \
|
||||
&& !defined(__BMI2__)
|
||||
# define DYNAMIC_BMI2 1
|
||||
#else
|
||||
# define DYNAMIC_BMI2 0
|
||||
#endif
|
||||
#endif
|
||||
|
||||
/* prefetch
|
||||
* can be disabled, by declaring NO_PREFETCH build macro */
|
||||
#if defined(NO_PREFETCH)
|
||||
# define PREFETCH_L1(ptr) (void)(ptr) /* disabled */
|
||||
# define PREFETCH_L2(ptr) (void)(ptr) /* disabled */
|
||||
#else
|
||||
# if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_I86)) /* _mm_prefetch() is not defined outside of x86/x64 */
|
||||
# include <mmintrin.h> /* https://msdn.microsoft.com/fr-fr/library/84szxsww(v=vs.90).aspx */
|
||||
# define PREFETCH_L1(ptr) _mm_prefetch((const char*)(ptr), _MM_HINT_T0)
|
||||
# define PREFETCH_L2(ptr) _mm_prefetch((const char*)(ptr), _MM_HINT_T1)
|
||||
# elif defined(__GNUC__) && ( (__GNUC__ >= 4) || ( (__GNUC__ == 3) && (__GNUC_MINOR__ >= 1) ) )
|
||||
# define PREFETCH_L1(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 3 /* locality */)
|
||||
# define PREFETCH_L2(ptr) __builtin_prefetch((ptr), 0 /* rw==read */, 2 /* locality */)
|
||||
# else
|
||||
# define PREFETCH_L1(ptr) (void)(ptr) /* disabled */
|
||||
# define PREFETCH_L2(ptr) (void)(ptr) /* disabled */
|
||||
# endif
|
||||
#endif /* NO_PREFETCH */
|
||||
|
||||
#define CACHELINE_SIZE 64
|
||||
|
||||
#define PREFETCH_AREA(p, s) { \
|
||||
const char* const _ptr = (const char*)(p); \
|
||||
size_t const _size = (size_t)(s); \
|
||||
size_t _pos; \
|
||||
for (_pos=0; _pos<_size; _pos+=CACHELINE_SIZE) { \
|
||||
PREFETCH_L2(_ptr + _pos); \
|
||||
} \
|
||||
}
|
||||
|
||||
/* disable warnings */
|
||||
#ifdef _MSC_VER /* Visual Studio */
|
||||
# include <intrin.h> /* For Visual 2005 */
|
||||
# pragma warning(disable : 4100) /* disable: C4100: unreferenced formal parameter */
|
||||
# pragma warning(disable : 4127) /* disable: C4127: conditional expression is constant */
|
||||
# pragma warning(disable : 4204) /* disable: C4204: non-constant aggregate initializer */
|
||||
# pragma warning(disable : 4214) /* disable: C4214: non-int bitfields */
|
||||
# pragma warning(disable : 4324) /* disable: C4324: padded structure */
|
||||
#endif
|
||||
|
||||
#endif /* ZSTD_COMPILER_H */
|
||||
1081
vendor/github.com/DataDog/zstd/cover.c
generated
vendored
Normal file
1081
vendor/github.com/DataDog/zstd/cover.c
generated
vendored
Normal file
File diff suppressed because it is too large
Load Diff
83
vendor/github.com/DataDog/zstd/cover.h
generated
vendored
Normal file
83
vendor/github.com/DataDog/zstd/cover.h
generated
vendored
Normal file
@@ -0,0 +1,83 @@
|
||||
#include <stdio.h> /* fprintf */
|
||||
#include <stdlib.h> /* malloc, free, qsort */
|
||||
#include <string.h> /* memset */
|
||||
#include <time.h> /* clock */
|
||||
#include "mem.h" /* read */
|
||||
#include "pool.h"
|
||||
#include "threading.h"
|
||||
#include "zstd_internal.h" /* includes zstd.h */
|
||||
#ifndef ZDICT_STATIC_LINKING_ONLY
|
||||
#define ZDICT_STATIC_LINKING_ONLY
|
||||
#endif
|
||||
#include "zdict.h"
|
||||
|
||||
/**
|
||||
* COVER_best_t is used for two purposes:
|
||||
* 1. Synchronizing threads.
|
||||
* 2. Saving the best parameters and dictionary.
|
||||
*
|
||||
* All of the methods except COVER_best_init() are thread safe if zstd is
|
||||
* compiled with multithreaded support.
|
||||
*/
|
||||
typedef struct COVER_best_s {
|
||||
ZSTD_pthread_mutex_t mutex;
|
||||
ZSTD_pthread_cond_t cond;
|
||||
size_t liveJobs;
|
||||
void *dict;
|
||||
size_t dictSize;
|
||||
ZDICT_cover_params_t parameters;
|
||||
size_t compressedSize;
|
||||
} COVER_best_t;
|
||||
|
||||
/**
|
||||
* A segment is a range in the source as well as the score of the segment.
|
||||
*/
|
||||
typedef struct {
|
||||
U32 begin;
|
||||
U32 end;
|
||||
U32 score;
|
||||
} COVER_segment_t;
|
||||
|
||||
/**
|
||||
* Checks total compressed size of a dictionary
|
||||
*/
|
||||
size_t COVER_checkTotalCompressedSize(const ZDICT_cover_params_t parameters,
|
||||
const size_t *samplesSizes, const BYTE *samples,
|
||||
size_t *offsets,
|
||||
size_t nbTrainSamples, size_t nbSamples,
|
||||
BYTE *const dict, size_t dictBufferCapacity);
|
||||
|
||||
/**
|
||||
* Returns the sum of the sample sizes.
|
||||
*/
|
||||
size_t COVER_sum(const size_t *samplesSizes, unsigned nbSamples) ;
|
||||
|
||||
/**
|
||||
* Initialize the `COVER_best_t`.
|
||||
*/
|
||||
void COVER_best_init(COVER_best_t *best);
|
||||
|
||||
/**
|
||||
* Wait until liveJobs == 0.
|
||||
*/
|
||||
void COVER_best_wait(COVER_best_t *best);
|
||||
|
||||
/**
|
||||
* Call COVER_best_wait() and then destroy the COVER_best_t.
|
||||
*/
|
||||
void COVER_best_destroy(COVER_best_t *best);
|
||||
|
||||
/**
|
||||
* Called when a thread is about to be launched.
|
||||
* Increments liveJobs.
|
||||
*/
|
||||
void COVER_best_start(COVER_best_t *best);
|
||||
|
||||
/**
|
||||
* Called when a thread finishes executing, both on error or success.
|
||||
* Decrements liveJobs and signals any waiting threads if liveJobs == 0.
|
||||
* If this dictionary is the best so far save it and its parameters.
|
||||
*/
|
||||
void COVER_best_finish(COVER_best_t *best, size_t compressedSize,
|
||||
ZDICT_cover_params_t parameters, void *dict,
|
||||
size_t dictSize);
|
||||
215
vendor/github.com/DataDog/zstd/cpu.h
generated
vendored
Normal file
215
vendor/github.com/DataDog/zstd/cpu.h
generated
vendored
Normal file
@@ -0,0 +1,215 @@
|
||||
/*
|
||||
* Copyright (c) 2018-present, Facebook, Inc.
|
||||
* All rights reserved.
|
||||
*
|
||||
* This source code is licensed under both the BSD-style license (found in the
|
||||
* LICENSE file in the root directory of this source tree) and the GPLv2 (found
|
||||
* in the COPYING file in the root directory of this source tree).
|
||||
* You may select, at your option, one of the above-listed licenses.
|
||||
*/
|
||||
|
||||
#ifndef ZSTD_COMMON_CPU_H
|
||||
#define ZSTD_COMMON_CPU_H
|
||||
|
||||
/**
|
||||
* Implementation taken from folly/CpuId.h
|
||||
* https://github.com/facebook/folly/blob/master/folly/CpuId.h
|
||||
*/
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include "mem.h"
|
||||
|
||||
#ifdef _MSC_VER
|
||||
#include <intrin.h>
|
||||
#endif
|
||||
|
||||
typedef struct {
|
||||
U32 f1c;
|
||||
U32 f1d;
|
||||
U32 f7b;
|
||||
U32 f7c;
|
||||
} ZSTD_cpuid_t;
|
||||
|
||||
MEM_STATIC ZSTD_cpuid_t ZSTD_cpuid(void) {
|
||||
U32 f1c = 0;
|
||||
U32 f1d = 0;
|
||||
U32 f7b = 0;
|
||||
U32 f7c = 0;
|
||||
#if defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86))
|
||||
int reg[4];
|
||||
__cpuid((int*)reg, 0);
|
||||
{
|
||||
int const n = reg[0];
|
||||
if (n >= 1) {
|
||||
__cpuid((int*)reg, 1);
|
||||
f1c = (U32)reg[2];
|
||||
f1d = (U32)reg[3];
|
||||
}
|
||||
if (n >= 7) {
|
||||
__cpuidex((int*)reg, 7, 0);
|
||||
f7b = (U32)reg[1];
|
||||
f7c = (U32)reg[2];
|
||||
}
|
||||
}
|
||||
#elif defined(__i386__) && defined(__PIC__) && !defined(__clang__) && defined(__GNUC__)
|
||||
/* The following block like the normal cpuid branch below, but gcc
|
||||
* reserves ebx for use of its pic register so we must specially
|
||||
* handle the save and restore to avoid clobbering the register
|
||||
*/
|
||||
U32 n;
|
||||
__asm__(
|
||||
"pushl %%ebx\n\t"
|
||||
"cpuid\n\t"
|
||||
"popl %%ebx\n\t"
|
||||
: "=a"(n)
|
||||
: "a"(0)
|
||||
: "ecx", "edx");
|
||||
if (n >= 1) {
|
||||
U32 f1a;
|
||||
__asm__(
|
||||
"pushl %%ebx\n\t"
|
||||
"cpuid\n\t"
|
||||
"popl %%ebx\n\t"
|
||||
: "=a"(f1a), "=c"(f1c), "=d"(f1d)
|
||||
: "a"(1));
|
||||
}
|
||||
if (n >= 7) {
|
||||
__asm__(
|
||||
"pushl %%ebx\n\t"
|
||||
"cpuid\n\t"
|
||||
"movl %%ebx, %%eax\n\t"
|
||||
"popl %%ebx"
|
||||
: "=a"(f7b), "=c"(f7c)
|
||||
: "a"(7), "c"(0)
|
||||
: "edx");
|
||||
}
|
||||
#elif defined(__x86_64__) || defined(_M_X64) || defined(__i386__)
|
||||
U32 n;
|
||||
__asm__("cpuid" : "=a"(n) : "a"(0) : "ebx", "ecx", "edx");
|
||||
if (n >= 1) {
|
||||
U32 f1a;
|
||||
__asm__("cpuid" : "=a"(f1a), "=c"(f1c), "=d"(f1d) : "a"(1) : "ebx");
|
||||
}
|
||||
if (n >= 7) {
|
||||
U32 f7a;
|
||||
__asm__("cpuid"
|
||||
: "=a"(f7a), "=b"(f7b), "=c"(f7c)
|
||||
: "a"(7), "c"(0)
|
||||
: "edx");
|
||||
}
|
||||
#endif
|
||||
{
|
||||
ZSTD_cpuid_t cpuid;
|
||||
cpuid.f1c = f1c;
|
||||
cpuid.f1d = f1d;
|
||||
cpuid.f7b = f7b;
|
||||
cpuid.f7c = f7c;
|
||||
return cpuid;
|
||||
}
|
||||
}
|
||||
|
||||
#define X(name, r, bit) \
|
||||
MEM_STATIC int ZSTD_cpuid_##name(ZSTD_cpuid_t const cpuid) { \
|
||||
return ((cpuid.r) & (1U << bit)) != 0; \
|
||||
}
|
||||
|
||||
/* cpuid(1): Processor Info and Feature Bits. */
|
||||
#define C(name, bit) X(name, f1c, bit)
|
||||
C(sse3, 0)
|
||||
C(pclmuldq, 1)
|
||||
C(dtes64, 2)
|
||||
C(monitor, 3)
|
||||
C(dscpl, 4)
|
||||
C(vmx, 5)
|
||||
C(smx, 6)
|
||||
C(eist, 7)
|
||||
C(tm2, 8)
|
||||
C(ssse3, 9)
|
||||
C(cnxtid, 10)
|
||||
C(fma, 12)
|
||||
C(cx16, 13)
|
||||
C(xtpr, 14)
|
||||
C(pdcm, 15)
|
||||
C(pcid, 17)
|
||||
C(dca, 18)
|
||||
C(sse41, 19)
|
||||
C(sse42, 20)
|
||||
C(x2apic, 21)
|
||||
C(movbe, 22)
|
||||
C(popcnt, 23)
|
||||
C(tscdeadline, 24)
|
||||
C(aes, 25)
|
||||
C(xsave, 26)
|
||||
C(osxsave, 27)
|
||||
C(avx, 28)
|
||||
C(f16c, 29)
|
||||
C(rdrand, 30)
|
||||
#undef C
|
||||
#define D(name, bit) X(name, f1d, bit)
|
||||
D(fpu, 0)
|
||||
D(vme, 1)
|
||||
D(de, 2)
|
||||
D(pse, 3)
|
||||
D(tsc, 4)
|
||||
D(msr, 5)
|
||||
D(pae, 6)
|
||||
D(mce, 7)
|
||||
D(cx8, 8)
|
||||
D(apic, 9)
|
||||
D(sep, 11)
|
||||
D(mtrr, 12)
|
||||
D(pge, 13)
|
||||
D(mca, 14)
|
||||
D(cmov, 15)
|
||||
D(pat, 16)
|
||||
D(pse36, 17)
|
||||
D(psn, 18)
|
||||
D(clfsh, 19)
|
||||
D(ds, 21)
|
||||
D(acpi, 22)
|
||||
D(mmx, 23)
|
||||
D(fxsr, 24)
|
||||
D(sse, 25)
|
||||
D(sse2, 26)
|
||||
D(ss, 27)
|
||||
D(htt, 28)
|
||||
D(tm, 29)
|
||||
D(pbe, 31)
|
||||
#undef D
|
||||
|
||||
/* cpuid(7): Extended Features. */
|
||||
#define B(name, bit) X(name, f7b, bit)
|
||||
B(bmi1, 3)
|
||||
B(hle, 4)
|
||||
B(avx2, 5)
|
||||
B(smep, 7)
|
||||
B(bmi2, 8)
|
||||
B(erms, 9)
|
||||
B(invpcid, 10)
|
||||
B(rtm, 11)
|
||||
B(mpx, 14)
|
||||
B(avx512f, 16)
|
||||
B(avx512dq, 17)
|
||||
B(rdseed, 18)
|
||||
B(adx, 19)
|
||||
B(smap, 20)
|
||||
B(avx512ifma, 21)
|
||||
B(pcommit, 22)
|
||||
B(clflushopt, 23)
|
||||
B(clwb, 24)
|
||||
B(avx512pf, 26)
|
||||
B(avx512er, 27)
|
||||
B(avx512cd, 28)
|
||||
B(sha, 29)
|
||||
B(avx512bw, 30)
|
||||
B(avx512vl, 31)
|
||||
#undef B
|
||||
#define C(name, bit) X(name, f7c, bit)
|
||||
C(prefetchwt1, 0)
|
||||
C(avx512vbmi, 1)
|
||||
#undef C
|
||||
|
||||
#undef X
|
||||
|
||||
#endif /* ZSTD_COMMON_CPU_H */
|
||||
44
vendor/github.com/DataDog/zstd/debug.c
generated
vendored
Normal file
44
vendor/github.com/DataDog/zstd/debug.c
generated
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
/* ******************************************************************
|
||||
debug
|
||||
Part of FSE library
|
||||
Copyright (C) 2013-present, Yann Collet.
|
||||
|
||||
BSD 2-Clause License (http://www.opensource.org/licenses/bsd-license.php)
|
||||
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are
|
||||
met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright
|
||||
notice, this list of conditions and the following disclaimer.
|
||||
* Redistributions in binary form must reproduce the above
|
||||
copyright notice, this list of conditions and the following disclaimer
|
||||
in the documentation and/or other materials provided with the
|
||||
distribution.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
||||
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
||||
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
||||
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
||||
OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
||||
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
||||
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
||||
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
||||
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
||||
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
You can contact the author at :
|
||||
- Source repository : https://github.com/Cyan4973/FiniteStateEntropy
|
||||
****************************************************************** */
|
||||
|
||||
|
||||
/*
|
||||
* This module only hosts one global variable
|
||||
* which can be used to dynamically influence the verbosity of traces,
|
||||
* such as DEBUGLOG and RAWLOG
|
||||
*/
|
||||
|
||||
#include "debug.h"
|
||||
|
||||
int g_debuglevel = DEBUGLEVEL;
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user