From d5bf973f53fdc1ba67242dfc33574ef8ebd72753 Mon Sep 17 00:00:00 2001 From: zhouyuyan Date: Mon, 24 Dec 2018 17:03:47 +0800 Subject: [PATCH] modify --- broker/broker.go | 48 ++-- broker/client.go | 263 ++++-------------- broker/sesson.go | 53 ++++ main.go | 3 +- sessions/memprovider.go | 76 ++++++ sessions/redisprovider.go | 95 +++++++ sessions/session.go | 149 +++++++++++ sessions/sessions.go | 92 +++++++ topics/memtopics.go | 549 ++++++++++++++++++++++++++++++++++++++ topics/topics.go | 91 +++++++ 10 files changed, 1198 insertions(+), 221 deletions(-) create mode 100644 broker/sesson.go create mode 100644 sessions/memprovider.go create mode 100644 sessions/redisprovider.go create mode 100644 sessions/session.go create mode 100644 sessions/sessions.go create mode 100644 topics/memtopics.go create mode 100644 topics/topics.go diff --git a/broker/broker.go b/broker/broker.go index 11006f5..ffae754 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -14,6 +14,8 @@ import ( "github.com/eclipse/paho.mqtt.golang/packets" "github.com/fhmq/hmq/lib/acl" "github.com/fhmq/hmq/pool" + "github.com/fhmq/hmq/sessions" + "github.com/fhmq/hmq/topics" "github.com/shirou/gopsutil/mem" "go.uber.org/zap" "golang.org/x/net/websocket" @@ -42,9 +44,9 @@ type Broker struct { remotes sync.Map nodes map[string]interface{} clusterPool chan *Message - sl *Sublist - rl *RetainList queues map[string]int + topicsMgr *topics.Manager + sessionMgr *sessions.Manager // messagePool []chan *Message } @@ -62,13 +64,24 @@ func NewBroker(config *Config) (*Broker, error) { id: GenUniqueId(), config: config, wpool: pool.New(config.Worker), - sl: NewSublist(), - rl: NewRetainList(), nodes: make(map[string]interface{}), queues: make(map[string]int), clusterPool: make(chan *Message), - // messagePool: newMessagePool(), } + + var err error + b.topicsMgr, err = topics.NewManager("mem") + if err != nil { + log.Error("new topic manager error", zap.Error(err)) + return nil, err + } + + b.sessionMgr, err = sessions.NewManager("mem") + if err != nil { + log.Error("new session manager error", zap.Error(err)) + return nil, err + } + if b.config.TlsPort != "" { tlsconfig, err := NewTLSConfig(b.config.TlsInfo) if err != nil { @@ -333,6 +346,12 @@ func (b *Broker) handleConnection(typ int, conn net.Conn) { c.init() + err = b.getSession(c, msg, connack) + if err != nil { + log.Error("get session error: ", zap.String("clientID", c.info.clientID)) + return + } + cid := c.info.clientID var exist bool @@ -535,9 +554,9 @@ func (b *Broker) SendLocalSubsToRouter(c *client) { b.clients.Range(func(key, value interface{}) bool { client, ok := value.(*client) if ok { - subs := client.subs + subs := client.subMap for _, sub := range subs { - subInfo.Topics = append(subInfo.Topics, string(sub.topic)) + subInfo.Topics = append(subInfo.Topics, sub.topic) subInfo.Qoss = append(subInfo.Qoss, sub.qos) } } @@ -593,15 +612,14 @@ func (b *Broker) removeClient(c *client) { } func (b *Broker) PublishMessage(packet *packets.PublishPacket) { - topic := packet.TopicName - r := b.sl.Match(topic) - if len(r.psubs) == 0 { - return - } + var subs []interface{} + var qoss []byte + b.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &subs, &qoss) - for _, sub := range r.psubs { - if sub != nil { - err := sub.client.WriterPacket(packet) + for _, sub := range subs { + s, ok := sub.(*subscription) + if ok { + err := s.client.WriterPacket(packet) if err != nil { log.Error("process message for psub error, ", zap.Error(err)) } diff --git a/broker/client.go b/broker/client.go index 30d012e..50eb426 100644 --- a/broker/client.go +++ b/broker/client.go @@ -12,6 +12,8 @@ import ( "time" "github.com/eclipse/paho.mqtt.golang/packets" + "github.com/fhmq/hmq/sessions" + "github.com/fhmq/hmq/topics" "go.uber.org/zap" ) @@ -39,11 +41,14 @@ type client struct { info info route route status int - smu sync.RWMutex - subs map[string]*subscription - rsubs map[string]*subInfo ctx context.Context cancelFunc context.CancelFunc + session *sessions.Session + subMap map[string]*subscription + topicsMgr *topics.Manager + subs []interface{} + qoss []byte + rmsgs []*packets.PublishPacket } type subInfo struct { @@ -78,44 +83,12 @@ var ( ) func (c *client) init() { - c.smu.Lock() - defer c.smu.Unlock() c.status = Connected - c.rsubs = make(map[string]*subInfo) - c.subs = make(map[string]*subscription, 10) c.info.localIP = strings.Split(c.conn.LocalAddr().String(), ":")[0] c.info.remoteIP = strings.Split(c.conn.RemoteAddr().String(), ":")[0] c.ctx, c.cancelFunc = context.WithCancel(context.Background()) -} -func (c *client) keepAlive(ch chan int) { - defer close(ch) - - b := c.broker - - keepalive := time.Duration(c.info.keepalive*3/2) * time.Second - timer := time.NewTimer(keepalive) - - for { - select { - case <-ch: - timer.Reset(keepalive) - case <-timer.C: - if c.typ == REMOTE || c.typ == CLUSTER { - timer.Reset(keepalive) - continue - } - log.Error("Client exceeded timeout, disconnecting. ", zap.String("ClientID", c.info.clientID), zap.Uint16("keepalive", c.info.keepalive)) - - msg := &Message{client: c, packet: DisconnectdPacket} - b.SubmitWork(msg) - - timer.Stop() - return - case <-c.ctx.Done(): - return - } - } + c.topicsMgr = c.broker.topicsMgr } func (c *client) readLoop() { @@ -125,14 +98,20 @@ func (c *client) readLoop() { return } - ch := make(chan int, 1000) - go c.keepAlive(ch) + keepAlive := time.Second * time.Duration(c.info.keepalive) + timeOut := keepAlive + (keepAlive / 2) for { select { case <-c.ctx.Done(): return default: + //add read timeout + if err := nc.SetReadDeadline(time.Now().Add(timeOut)); err != nil { + log.Error("set read timeout error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) + return + } + packet, err := packets.ReadPacket(nc) if err != nil { log.Error("read packet error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) @@ -140,8 +119,6 @@ func (c *client) readLoop() { b.SubmitWork(msg) return } - // keepalive channel - ch <- 1 msg := &Message{ client: c, @@ -159,7 +136,6 @@ func ProcessMessage(msg *Message) { if ca == nil { return } - log.Debug("Recv message:", zap.String("message type", reflect.TypeOf(msg.packet).String()[9:]), zap.String("ClientID", c.info.clientID)) switch ca.(type) { case *packets.ConnackPacket: @@ -222,14 +198,6 @@ func (c *client) ProcessPublish(packet *packets.PublishPacket) { log.Error("publish with unknown qos", zap.String("ClientID", c.info.clientID)) return } - if packet.Retain { - if b := c.broker; b != nil { - err := b.rl.Insert(topic, packet) - if err != nil { - log.Error("Insert Retain Message error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - } - } - } } @@ -243,85 +211,41 @@ func (c *client) ProcessPublishMessage(packet *packets.PublishPacket) { return } typ := c.typ - topic := packet.TopicName - r := b.sl.Match(topic) - if r == nil { + if packet.Retain { + if err := c.topicsMgr.Retain(packet); err != nil { + log.Error("Error retaining message: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) + } + } + + err := c.topicsMgr.Subscribers([]byte(packet.TopicName), packet.Qos, &c.subs, &c.qoss) + if err != nil { + log.Error("Error retrieving subscribers list: ", zap.String("ClientID", c.info.clientID)) return } // log.Info("psubs num: ", len(r.psubs)) - if len(r.qsubs) == 0 && len(r.psubs) == 0 { + if len(c.subs) == 0 { return } - for _, sub := range r.psubs { - if sub.client.typ == ROUTER { - if typ != CLIENT { - continue - } - } - if sub != nil { - err := sub.client.WriterPacket(packet) - if err != nil { - log.Error("process message for psub error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - } - } - } - - pre := -1 - now := -1 - t := "$queue/" + topic - cnt, exist := b.queues[t] - if exist { - // log.Info("queue index : ", cnt) - for _, sub := range r.qsubs { - if sub.client.typ == ROUTER { + for _, sub := range c.subs { + s, ok := sub.(*subscription) + if ok { + if s.client.typ == ROUTER { if typ != CLIENT { continue } } - if c.typ == CLIENT { - now = now + 1 - } else { - now = now + sub.client.rsubs[t].num + + err := s.client.WriterPacket(packet) + if err != nil { + log.Error("process message for psub error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) } - if cnt > pre && cnt <= now { - if sub != nil { - err := sub.client.WriterPacket(packet) - if err != nil { - log.Error("send publish error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - } - } - - break - } - pre = now } + } - length := getQueueSubscribeNum(r.qsubs) - if length > 0 { - b.queues[t] = (b.queues[t] + 1) % length - } -} - -func getQueueSubscribeNum(qsubs []*subscription) int { - topic := "$queue/" - if len(qsubs) < 1 { - return 0 - } else { - topic = topic + qsubs[0].topic - } - num := 0 - for _, sub := range qsubs { - if sub.client.typ == CLIENT { - num = num + 1 - } else { - num = num + sub.client.rsubs[topic].num - } - } - return num } func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) { @@ -349,54 +273,24 @@ func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) { continue } - queue := strings.HasPrefix(topic, "$queue/") - if queue { - if len(t) > 7 { - t = t[7:] - if _, exists := b.queues[topic]; !exists { - b.queues[topic] = 0 - } - } else { - retcodes = append(retcodes, QosFailure) - continue - } - } sub := &subscription{ topic: t, qos: qoss[i], client: c, - queue: queue, } - switch c.typ { - case CLIENT: - if _, exist := c.subs[topic]; !exist { - c.subs[topic] = sub - } else { - //if exist ,check whether qos change - c.subs[topic].qos = qoss[i] - retcodes = append(retcodes, qoss[i]) - continue - } - case ROUTER: - if subinfo, exist := c.rsubs[topic]; !exist { - sinfo := &subInfo{sub: sub, num: 1} - c.rsubs[topic] = sinfo - - } else { - subinfo.num = subinfo.num + 1 - retcodes = append(retcodes, qoss[i]) - continue - } - } - err := b.sl.Insert(sub) + rqos, err := c.topicsMgr.Subscribe([]byte(topic), qoss[i], sub) if err != nil { - log.Error("Insert subscription error: ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - retcodes = append(retcodes, QosFailure) - } else { - retcodes = append(retcodes, qoss[i]) + return } + + c.subMap[topic] = sub + c.session.AddTopic(topic, qoss[i]) + retcodes = append(retcodes, rqos) + c.topicsMgr.Retained([]byte(topic), &c.rmsgs) + } + suback.ReturnCodes = retcodes err := c.WriterPacket(suback) @@ -410,16 +304,11 @@ func (c *client) ProcessSubscribe(packet *packets.SubscribePacket) { } //process retain message - for _, t := range topics { - packets := b.rl.Match(t) - if packets == nil { - continue - } - for _, packet := range packets { + for _, rm := range c.rmsgs { + if err := c.WriterPacket(rm); err != nil { + log.Error("Error publishing retained message:", zap.Any("err", err), zap.String("ClientID", c.info.clientID)) + } else { log.Info("process retain message: ", zap.Any("packet", packet), zap.String("ClientID", c.info.clientID)) - if packet != nil { - c.WriterPacket(packet) - } } } } @@ -432,30 +321,16 @@ func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) { if b == nil { return } - typ := c.typ topics := packet.Topics - for _, t := range topics { - - switch typ { - case CLIENT: - sub, ok := c.subs[t] - if ok { - c.unsubscribe(sub) - } - case ROUTER: - subinfo, ok := c.rsubs[t] - if ok { - subinfo.num = subinfo.num - 1 - if subinfo.num < 1 { - delete(c.rsubs, t) - c.unsubscribe(subinfo.sub) - } else { - c.rsubs[t] = subinfo - } - } + for _, topic := range topics { + t := []byte(topic) + sub, exist := c.subMap[topic] + if exist { + c.topicsMgr.Unsubscribe(t, sub) + c.session.RemoveTopic(topic) + delete(c.subMap, topic) } - } unsuback := packets.NewControlPacket(packets.Unsuback).(*packets.UnsubackPacket) @@ -472,19 +347,6 @@ func (c *client) ProcessUnSubscribe(packet *packets.UnsubscribePacket) { } } -func (c *client) unsubscribe(sub *subscription) { - - if c.typ == CLIENT { - delete(c.subs, sub.topic) - - } - b := c.broker - if b != nil && sub != nil { - b.sl.Remove(sub) - } - -} - func (c *client) ProcessPing() { if c.status == Disconnected { return @@ -498,9 +360,7 @@ func (c *client) ProcessPing() { } func (c *client) Close() { - c.smu.Lock() if c.status == Disconnected { - c.smu.Unlock() return } @@ -516,18 +376,11 @@ func (c *client) Close() { c.conn = nil } - c.smu.Unlock() - b := c.broker - subs := c.subs + subs := c.subMap if b != nil { b.removeClient(c) - for _, sub := range subs { - err := b.sl.Remove(sub) - if err != nil { - log.Error("closed client but remove sublist error, ", zap.Error(err), zap.String("ClientID", c.info.clientID)) - } - } + if c.typ == CLIENT { b.BroadcastUnSubscribe(subs) } diff --git a/broker/sesson.go b/broker/sesson.go new file mode 100644 index 0000000..59d5c11 --- /dev/null +++ b/broker/sesson.go @@ -0,0 +1,53 @@ +package broker + +import "github.com/eclipse/paho.mqtt.golang/packets" + +func (b *Broker) getSession(cli *client, req *packets.ConnectPacket, resp *packets.ConnackPacket) error { + // If CleanSession is set to 0, the server MUST resume communications with the + // client based on state from the current session, as identified by the client + // identifier. If there is no session associated with the client identifier the + // server must create a new session. + // + // If CleanSession is set to 1, the client and server must discard any previous + // session and start a new one. b session lasts as long as the network c + // onnection. State data associated with b session must not be reused in any + // subsequent session. + + var err error + + // Check to see if the client supplied an ID, if not, generate one and set + // clean session. + + if len(req.ClientIdentifier) == 0 { + req.CleanSession = true + } + + cid := req.ClientIdentifier + + // If CleanSession is NOT set, check the session store for existing session. + // If found, return it. + if !req.CleanSession { + if cli.session, err = b.sessionMgr.Get(cid); err == nil { + resp.SessionPresent = true + + if err := cli.session.Update(req); err != nil { + return err + } + } + } + + // If CleanSession, or no existing session found, then create a new one + if cli.session == nil { + if cli.session, err = b.sessionMgr.New(cid); err != nil { + return err + } + + resp.SessionPresent = false + + if err := cli.session.Init(req); err != nil { + return err + } + } + + return nil +} diff --git a/main.go b/main.go index f58d31d..0062ff6 100644 --- a/main.go +++ b/main.go @@ -8,10 +8,11 @@ package main import ( "fmt" - "github.com/fhmq/hmq/broker" "os" "os/signal" "runtime" + + "github.com/fhmq/hmq/broker" ) func main() { diff --git a/sessions/memprovider.go b/sessions/memprovider.go new file mode 100644 index 0000000..d82d117 --- /dev/null +++ b/sessions/memprovider.go @@ -0,0 +1,76 @@ +// Copyright (c) 2014 The SurgeMQ Authors. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sessions + +import ( + "fmt" + "sync" +) + +var _ SessionsProvider = (*memProvider)(nil) + +func init() { + Register("mem", NewMemProvider()) +} + +type memProvider struct { + st map[string]*Session + mu sync.RWMutex +} + +func NewMemProvider() *memProvider { + return &memProvider{ + st: make(map[string]*Session), + } +} + +func (this *memProvider) New(id string) (*Session, error) { + this.mu.Lock() + defer this.mu.Unlock() + + this.st[id] = &Session{id: id} + return this.st[id], nil +} + +func (this *memProvider) Get(id string) (*Session, error) { + this.mu.RLock() + defer this.mu.RUnlock() + + sess, ok := this.st[id] + if !ok { + return nil, fmt.Errorf("store/Get: No session found for key %s", id) + } + + return sess, nil +} + +func (this *memProvider) Del(id string) { + this.mu.Lock() + defer this.mu.Unlock() + delete(this.st, id) +} + +func (this *memProvider) Save(id string) error { + return nil +} + +func (this *memProvider) Count() int { + return len(this.st) +} + +func (this *memProvider) Close() error { + this.st = make(map[string]*Session) + return nil +} diff --git a/sessions/redisprovider.go b/sessions/redisprovider.go new file mode 100644 index 0000000..30c701d --- /dev/null +++ b/sessions/redisprovider.go @@ -0,0 +1,95 @@ +package sessions + +import ( + "time" + + log "github.com/cihub/seelog" + "github.com/go-redis/redis" + jsoniter "github.com/json-iterator/go" +) + +var redisClient *redis.Client +var _ SessionsProvider = (*redisProvider)(nil) + +const ( + sessionName = "session" +) + +type redisProvider struct { +} + +func init() { + Register("redis", NewRedisProvider()) +} + +func InitRedisConn(url string) { + redisClient = redis.NewClient(&redis.Options{ + Addr: "127.0.0.1:6379", + Password: "", // no password set + DB: 0, // use default DB + }) + err := redisClient.Ping().Err() + for err != nil { + log.Error("connect redis error: ", err, " 3s try again...") + time.Sleep(3 * time.Second) + err = redisClient.Ping().Err() + } +} + +func NewRedisProvider() *redisProvider { + return &redisProvider{} +} + +func (r *redisProvider) New(id string) (*Session, error) { + val, _ := jsoniter.Marshal(&Session{id: id}) + + err := redisClient.HSet(sessionName, id, val).Err() + if err != nil { + return nil, err + } + + result, err := redisClient.HGet(sessionName, id).Bytes() + if err != nil { + return nil, err + } + + sess := Session{} + err = jsoniter.Unmarshal(result, &sess) + if err != nil { + return nil, err + } + + return &sess, nil +} + +func (r *redisProvider) Get(id string) (*Session, error) { + + result, err := redisClient.HGet(sessionName, id).Bytes() + if err != nil { + return nil, err + } + + sess := Session{} + err = jsoniter.Unmarshal(result, &sess) + if err != nil { + return nil, err + } + + return &sess, nil +} + +func (r *redisProvider) Del(id string) { + redisClient.HDel(sessionName, id) +} + +func (r *redisProvider) Save(id string) error { + return nil +} + +func (r *redisProvider) Count() int { + return int(redisClient.HLen(sessionName).Val()) +} + +func (r *redisProvider) Close() error { + return redisClient.Del(sessionName).Err() +} diff --git a/sessions/session.go b/sessions/session.go new file mode 100644 index 0000000..83b8d29 --- /dev/null +++ b/sessions/session.go @@ -0,0 +1,149 @@ +package sessions + +import ( + "fmt" + "sync" + + "github.com/eclipse/paho.mqtt.golang/packets" +) + +const ( + // Queue size for the ack queue + defaultQueueSize = 16 +) + +type Session struct { + + // cmsg is the CONNECT message + cmsg *packets.ConnectPacket + + // Will message to publish if connect is closed unexpectedly + Will *packets.PublishPacket + + // Retained publish message + Retained *packets.PublishPacket + + // topics stores all the topis for this session/client + topics map[string]byte + + // Initialized? + initted bool + + // Serialize access to this session + mu sync.Mutex + + id string +} + +func (this *Session) Init(msg *packets.ConnectPacket) error { + this.mu.Lock() + defer this.mu.Unlock() + + if this.initted { + return fmt.Errorf("Session already initialized") + } + + this.cmsg = msg + + if this.cmsg.WillFlag { + this.Will = packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) + this.Will.Qos = this.cmsg.Qos + this.Will.TopicName = this.cmsg.WillTopic + this.Will.Payload = this.cmsg.WillMessage + this.Will.Retain = this.cmsg.WillRetain + } + + this.topics = make(map[string]byte, 1) + + this.id = string(msg.ClientIdentifier) + + this.initted = true + + return nil +} + +func (this *Session) Update(msg *packets.ConnectPacket) error { + this.mu.Lock() + defer this.mu.Unlock() + + this.cmsg = msg + return nil +} + +func (this *Session) RetainMessage(msg *packets.PublishPacket) error { + this.mu.Lock() + defer this.mu.Unlock() + + this.Retained = msg + + return nil +} + +func (this *Session) AddTopic(topic string, qos byte) error { + this.mu.Lock() + defer this.mu.Unlock() + + if !this.initted { + return fmt.Errorf("Session not yet initialized") + } + + this.topics[topic] = qos + + return nil +} + +func (this *Session) RemoveTopic(topic string) error { + this.mu.Lock() + defer this.mu.Unlock() + + if !this.initted { + return fmt.Errorf("Session not yet initialized") + } + + delete(this.topics, topic) + + return nil +} + +func (this *Session) Topics() ([]string, []byte, error) { + this.mu.Lock() + defer this.mu.Unlock() + + if !this.initted { + return nil, nil, fmt.Errorf("Session not yet initialized") + } + + var ( + topics []string + qoss []byte + ) + + for k, v := range this.topics { + topics = append(topics, k) + qoss = append(qoss, v) + } + + return topics, qoss, nil +} + +func (this *Session) ID() string { + return this.cmsg.ClientIdentifier +} + +func (this *Session) WillFlag() bool { + this.mu.Lock() + defer this.mu.Unlock() + return this.cmsg.WillFlag +} + +func (this *Session) SetWillFlag(v bool) { + this.mu.Lock() + defer this.mu.Unlock() + this.cmsg.WillFlag = v +} + +func (this *Session) CleanSession() bool { + this.mu.Lock() + defer this.mu.Unlock() + return this.cmsg.CleanSession +} diff --git a/sessions/sessions.go b/sessions/sessions.go new file mode 100644 index 0000000..b160d51 --- /dev/null +++ b/sessions/sessions.go @@ -0,0 +1,92 @@ +package sessions + +import ( + "crypto/rand" + "encoding/base64" + "errors" + "fmt" + "io" +) + +var ( + ErrSessionsProviderNotFound = errors.New("Session: Session provider not found") + ErrKeyNotAvailable = errors.New("Session: not item found for key.") + + providers = make(map[string]SessionsProvider) +) + +type SessionsProvider interface { + New(id string) (*Session, error) + Get(id string) (*Session, error) + Del(id string) + Save(id string) error + Count() int + Close() error +} + +// Register makes a session provider available by the provided name. +// If a Register is called twice with the same name or if the driver is nil, +// it panics. +func Register(name string, provider SessionsProvider) { + if provider == nil { + panic("session: Register provide is nil") + } + + if _, dup := providers[name]; dup { + panic("session: Register called twice for provider " + name) + } + + providers[name] = provider +} + +func Unregister(name string) { + delete(providers, name) +} + +type Manager struct { + p SessionsProvider +} + +func NewManager(providerName string) (*Manager, error) { + p, ok := providers[providerName] + if !ok { + return nil, fmt.Errorf("session: unknown provider %q", providerName) + } + + return &Manager{p: p}, nil +} + +func (this *Manager) New(id string) (*Session, error) { + if id == "" { + id = this.sessionId() + } + return this.p.New(id) +} + +func (this *Manager) Get(id string) (*Session, error) { + return this.p.Get(id) +} + +func (this *Manager) Del(id string) { + this.p.Del(id) +} + +func (this *Manager) Save(id string) error { + return this.p.Save(id) +} + +func (this *Manager) Count() int { + return this.p.Count() +} + +func (this *Manager) Close() error { + return this.p.Close() +} + +func (manager *Manager) sessionId() string { + b := make([]byte, 15) + if _, err := io.ReadFull(rand.Reader, b); err != nil { + return "" + } + return base64.URLEncoding.EncodeToString(b) +} diff --git a/topics/memtopics.go b/topics/memtopics.go new file mode 100644 index 0000000..851c8f1 --- /dev/null +++ b/topics/memtopics.go @@ -0,0 +1,549 @@ +package topics + +import ( + "fmt" + "reflect" + "sync" + + "github.com/eclipse/paho.mqtt.golang/packets" +) + +const ( + QosAtMostOnce byte = iota + QosAtLeastOnce + QosExactlyOnce + QosFailure = 0x80 +) + +var _ TopicsProvider = (*memTopics)(nil) + +type memTopics struct { + // Sub/unsub mutex + smu sync.RWMutex + // Subscription tree + sroot *snode + + // Retained message mutex + rmu sync.RWMutex + // Retained messages topic tree + rroot *rnode +} + +func init() { + Register("mem", NewMemProvider()) +} + +// NewMemProvider returns an new instance of the memTopics, which is implements the +// TopicsProvider interface. memProvider is a hidden struct that stores the topic +// subscriptions and retained messages in memory. The content is not persistend so +// when the server goes, everything will be gone. Use with care. +func NewMemProvider() *memTopics { + return &memTopics{ + sroot: newSNode(), + rroot: newRNode(), + } +} + +func ValidQos(qos byte) bool { + return qos == QosAtMostOnce || qos == QosAtLeastOnce || qos == QosExactlyOnce +} + +func (this *memTopics) Subscribe(topic []byte, qos byte, sub interface{}) (byte, error) { + if !ValidQos(qos) { + return QosFailure, fmt.Errorf("Invalid QoS %d", qos) + } + + if sub == nil { + return QosFailure, fmt.Errorf("Subscriber cannot be nil") + } + + this.smu.Lock() + defer this.smu.Unlock() + + if qos > QosExactlyOnce { + qos = QosExactlyOnce + } + + if err := this.sroot.sinsert(topic, qos, sub); err != nil { + return QosFailure, err + } + + return qos, nil +} + +func (this *memTopics) Unsubscribe(topic []byte, sub interface{}) error { + this.smu.Lock() + defer this.smu.Unlock() + + return this.sroot.sremove(topic, sub) +} + +// Returned values will be invalidated by the next Subscribers call +func (this *memTopics) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error { + if !ValidQos(qos) { + return fmt.Errorf("Invalid QoS %d", qos) + } + + this.smu.RLock() + defer this.smu.RUnlock() + + *subs = (*subs)[0:0] + *qoss = (*qoss)[0:0] + + return this.sroot.smatch(topic, qos, subs, qoss) +} + +func (this *memTopics) Retain(msg *packets.PublishPacket) error { + this.rmu.Lock() + defer this.rmu.Unlock() + + // So apparently, at least according to the MQTT Conformance/Interoperability + // Testing, that a payload of 0 means delete the retain message. + // https://eclipse.org/paho/clients/testing/ + if len(msg.Payload) == 0 { + return this.rroot.rremove([]byte(msg.TopicName)) + } + + return this.rroot.rinsert([]byte(msg.TopicName), msg) +} + +func (this *memTopics) Retained(topic []byte, msgs *[]*packets.PublishPacket) error { + this.rmu.RLock() + defer this.rmu.RUnlock() + + return this.rroot.rmatch(topic, msgs) +} + +func (this *memTopics) Close() error { + this.sroot = nil + this.rroot = nil + return nil +} + +// subscrition nodes +type snode struct { + // If this is the end of the topic string, then add subscribers here + subs []interface{} + qos []byte + + // Otherwise add the next topic level here + snodes map[string]*snode +} + +func newSNode() *snode { + return &snode{ + snodes: make(map[string]*snode), + } +} + +func (this *snode) sinsert(topic []byte, qos byte, sub interface{}) error { + // If there's no more topic levels, that means we are at the matching snode + // to insert the subscriber. So let's see if there's such subscriber, + // if so, update it. Otherwise insert it. + if len(topic) == 0 { + // Let's see if the subscriber is already on the list. If yes, update + // QoS and then return. + for i := range this.subs { + if equal(this.subs[i], sub) { + this.qos[i] = qos + return nil + } + } + + // Otherwise add. + this.subs = append(this.subs, sub) + this.qos = append(this.qos, qos) + + return nil + } + + // Not the last level, so let's find or create the next level snode, and + // recursively call it's insert(). + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + // Add snode if it doesn't already exist + n, ok := this.snodes[level] + if !ok { + n = newSNode() + this.snodes[level] = n + } + + return n.sinsert(rem, qos, sub) +} + +// This remove implementation ignores the QoS, as long as the subscriber +// matches then it's removed +func (this *snode) sremove(topic []byte, sub interface{}) error { + // If the topic is empty, it means we are at the final matching snode. If so, + // let's find the matching subscribers and remove them. + if len(topic) == 0 { + // If subscriber == nil, then it's signal to remove ALL subscribers + if sub == nil { + this.subs = this.subs[0:0] + this.qos = this.qos[0:0] + return nil + } + + // If we find the subscriber then remove it from the list. Technically + // we just overwrite the slot by shifting all other items up by one. + for i := range this.subs { + if equal(this.subs[i], sub) { + this.subs = append(this.subs[:i], this.subs[i+1:]...) + this.qos = append(this.qos[:i], this.qos[i+1:]...) + return nil + } + } + + return fmt.Errorf("memtopics/remove: No topic found for subscriber") + } + + // Not the last level, so let's find the next level snode, and recursively + // call it's remove(). + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + // Find the snode that matches the topic level + n, ok := this.snodes[level] + if !ok { + return fmt.Errorf("memtopics/remove: No topic found") + } + + // Remove the subscriber from the next level snode + if err := n.sremove(rem, sub); err != nil { + return err + } + + // If there are no more subscribers and snodes to the next level we just visited + // let's remove it + if len(n.subs) == 0 && len(n.snodes) == 0 { + delete(this.snodes, level) + } + + return nil +} + +// smatch() returns all the subscribers that are subscribed to the topic. Given a topic +// with no wildcards (publish topic), it returns a list of subscribers that subscribes +// to the topic. For each of the level names, it's a match +// - if there are subscribers to '#', then all the subscribers are added to result set +func (this *snode) smatch(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error { + // If the topic is empty, it means we are at the final matching snode. If so, + // let's find the subscribers that match the qos and append them to the list. + if len(topic) == 0 { + this.matchQos(qos, subs, qoss) + return nil + } + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + for k, n := range this.snodes { + // If the key is "#", then these subscribers are added to the result set + if k == MWC { + n.matchQos(qos, subs, qoss) + } else if k == SWC || k == level { + if err := n.smatch(rem, qos, subs, qoss); err != nil { + return err + } + } + } + + return nil +} + +// retained message nodes +type rnode struct { + // If this is the end of the topic string, then add retained messages here + msg *packets.PublishPacket + // Otherwise add the next topic level here + rnodes map[string]*rnode +} + +func newRNode() *rnode { + return &rnode{ + rnodes: make(map[string]*rnode), + } +} + +func (this *rnode) rinsert(topic []byte, msg *packets.PublishPacket) error { + // If there's no more topic levels, that means we are at the matching rnode. + if len(topic) == 0 { + // Reuse the message if possible + if this.msg == nil { + this.msg = msg + } + + return nil + } + + // Not the last level, so let's find or create the next level snode, and + // recursively call it's insert(). + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + // Add snode if it doesn't already exist + n, ok := this.rnodes[level] + if !ok { + n = newRNode() + this.rnodes[level] = n + } + + return n.rinsert(rem, msg) +} + +// Remove the retained message for the supplied topic +func (this *rnode) rremove(topic []byte) error { + // If the topic is empty, it means we are at the final matching rnode. If so, + // let's remove the buffer and message. + if len(topic) == 0 { + this.msg = nil + return nil + } + + // Not the last level, so let's find the next level rnode, and recursively + // call it's remove(). + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + // Find the rnode that matches the topic level + n, ok := this.rnodes[level] + if !ok { + return fmt.Errorf("memtopics/rremove: No topic found") + } + + // Remove the subscriber from the next level rnode + if err := n.rremove(rem); err != nil { + return err + } + + // If there are no more rnodes to the next level we just visited let's remove it + if len(n.rnodes) == 0 { + delete(this.rnodes, level) + } + + return nil +} + +// rmatch() finds the retained messages for the topic and qos provided. It's somewhat +// of a reverse match compare to match() since the supplied topic can contain +// wildcards, whereas the retained message topic is a full (no wildcard) topic. +func (this *rnode) rmatch(topic []byte, msgs *[]*packets.PublishPacket) error { + // If the topic is empty, it means we are at the final matching rnode. If so, + // add the retained msg to the list. + if len(topic) == 0 { + if this.msg != nil { + *msgs = append(*msgs, this.msg) + } + return nil + } + + // ntl = next topic level + ntl, rem, err := nextTopicLevel(topic) + if err != nil { + return err + } + + level := string(ntl) + + if level == MWC { + // If '#', add all retained messages starting this node + this.allRetained(msgs) + } else if level == SWC { + // If '+', check all nodes at this level. Next levels must be matched. + for _, n := range this.rnodes { + if err := n.rmatch(rem, msgs); err != nil { + return err + } + } + } else { + // Otherwise, find the matching node, go to the next level + if n, ok := this.rnodes[level]; ok { + if err := n.rmatch(rem, msgs); err != nil { + return err + } + } + } + + return nil +} + +func (this *rnode) allRetained(msgs *[]*packets.PublishPacket) { + if this.msg != nil { + *msgs = append(*msgs, this.msg) + } + + for _, n := range this.rnodes { + n.allRetained(msgs) + } +} + +const ( + stateCHR byte = iota // Regular character + stateMWC // Multi-level wildcard + stateSWC // Single-level wildcard + stateSEP // Topic level separator + stateSYS // System level topic ($) +) + +// Returns topic level, remaining topic levels and any errors +func nextTopicLevel(topic []byte) ([]byte, []byte, error) { + s := stateCHR + + for i, c := range topic { + switch c { + case '/': + if s == stateMWC { + return nil, nil, fmt.Errorf("memtopics/nextTopicLevel: Multi-level wildcard found in topic and it's not at the last level") + } + + if i == 0 { + return []byte(SWC), topic[i+1:], nil + } + + return topic[:i], topic[i+1:], nil + + case '#': + if i != 0 { + return nil, nil, fmt.Errorf("memtopics/nextTopicLevel: Wildcard character '#' must occupy entire topic level") + } + + s = stateMWC + + case '+': + if i != 0 { + return nil, nil, fmt.Errorf("memtopics/nextTopicLevel: Wildcard character '+' must occupy entire topic level") + } + + s = stateSWC + + case '$': + if i == 0 { + return nil, nil, fmt.Errorf("memtopics/nextTopicLevel: Cannot publish to $ topics") + } + + s = stateSYS + + default: + if s == stateMWC || s == stateSWC { + return nil, nil, fmt.Errorf("memtopics/nextTopicLevel: Wildcard characters '#' and '+' must occupy entire topic level") + } + + s = stateCHR + } + } + + // If we got here that means we didn't hit the separator along the way, so the + // topic is either empty, or does not contain a separator. Either way, we return + // the full topic + return topic, nil, nil +} + +// The QoS of the payload messages sent in response to a subscription must be the +// minimum of the QoS of the originally published message (in this case, it's the +// qos parameter) and the maximum QoS granted by the server (in this case, it's +// the QoS in the topic tree). +// +// It's also possible that even if the topic matches, the subscriber is not included +// due to the QoS granted is lower than the published message QoS. For example, +// if the client is granted only QoS 0, and the publish message is QoS 1, then this +// client is not to be send the published message. +func (this *snode) matchQos(qos byte, subs *[]interface{}, qoss *[]byte) { + for _, sub := range this.subs { + // If the published QoS is higher than the subscriber QoS, then we skip the + // subscriber. Otherwise, add to the list. + // if qos >= this.qos[i] { + *subs = append(*subs, sub) + *qoss = append(*qoss, qos) + // } + } +} + +func equal(k1, k2 interface{}) bool { + if reflect.TypeOf(k1) != reflect.TypeOf(k2) { + return false + } + + if reflect.ValueOf(k1).Kind() == reflect.Func { + return &k1 == &k2 + } + + if k1 == k2 { + return true + } + + switch k1 := k1.(type) { + case string: + return k1 == k2.(string) + + case int64: + return k1 == k2.(int64) + + case int32: + return k1 == k2.(int32) + + case int16: + return k1 == k2.(int16) + + case int8: + return k1 == k2.(int8) + + case int: + return k1 == k2.(int) + + case float32: + return k1 == k2.(float32) + + case float64: + return k1 == k2.(float64) + + case uint: + return k1 == k2.(uint) + + case uint8: + return k1 == k2.(uint8) + + case uint16: + return k1 == k2.(uint16) + + case uint32: + return k1 == k2.(uint32) + + case uint64: + return k1 == k2.(uint64) + + case uintptr: + return k1 == k2.(uintptr) + } + + return false +} diff --git a/topics/topics.go b/topics/topics.go new file mode 100644 index 0000000..b99696a --- /dev/null +++ b/topics/topics.go @@ -0,0 +1,91 @@ +package topics + +import ( + "fmt" + + "github.com/eclipse/paho.mqtt.golang/packets" +) + +const ( + // MWC is the multi-level wildcard + MWC = "#" + + // SWC is the single level wildcard + SWC = "+" + + // SEP is the topic level separator + SEP = "/" + + // SYS is the starting character of the system level topics + SYS = "$" + + // Both wildcards + _WC = "#+" +) + +var ( + providers = make(map[string]TopicsProvider) +) + +// TopicsProvider +type TopicsProvider interface { + Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error) + Unsubscribe(topic []byte, subscriber interface{}) error + Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error + Retain(msg *packets.PublishPacket) error + Retained(topic []byte, msgs *[]*packets.PublishPacket) error + Close() error +} + +func Register(name string, provider TopicsProvider) { + if provider == nil { + panic("topics: Register provide is nil") + } + + if _, dup := providers[name]; dup { + panic("topics: Register called twice for provider " + name) + } + + providers[name] = provider +} + +func Unregister(name string) { + delete(providers, name) +} + +type Manager struct { + p TopicsProvider +} + +func NewManager(providerName string) (*Manager, error) { + p, ok := providers[providerName] + if !ok { + return nil, fmt.Errorf("session: unknown provider %q", providerName) + } + + return &Manager{p: p}, nil +} + +func (this *Manager) Subscribe(topic []byte, qos byte, subscriber interface{}) (byte, error) { + return this.p.Subscribe(topic, qos, subscriber) +} + +func (this *Manager) Unsubscribe(topic []byte, subscriber interface{}) error { + return this.p.Unsubscribe(topic, subscriber) +} + +func (this *Manager) Subscribers(topic []byte, qos byte, subs *[]interface{}, qoss *[]byte) error { + return this.p.Subscribers(topic, qos, subs, qoss) +} + +func (this *Manager) Retain(msg *packets.PublishPacket) error { + return this.p.Retain(msg) +} + +func (this *Manager) Retained(topic []byte, msgs *[]*packets.PublishPacket) error { + return this.p.Retained(topic, msgs) +} + +func (this *Manager) Close() error { + return this.p.Close() +}