diff --git a/broker/broker.go b/broker/broker.go index 414f0e6..29e1028 100644 --- a/broker/broker.go +++ b/broker/broker.go @@ -4,6 +4,7 @@ import ( "crypto/tls" "hmq/lib/acl" "hmq/lib/message" + "hmq/packets" "net" "net/http" "sync/atomic" @@ -168,41 +169,44 @@ func (b *Broker) StartListening(typ int) { func (b *Broker) handleConnection(typ int, conn net.Conn, idx uint64) { //process connect packet - buf, err := ReadPacket(conn) + packet, err := packets.ReadPacket(conn) if err != nil { log.Error("read connect packet error: ", err) return } - connMsg, err := DecodeConnectMessage(buf) + if packet == nil { + log.Error("received nil packet") + return + } + msg, ok := packet.(*packets.ConnectPacket) + if !ok { + log.Error("received msg that was not Connect") + return + } + connack := packets.NewControlPacket(packets.Connack).(*packets.ConnackPacket) + connack.ReturnCode = packets.Accepted + connack.SessionPresent = msg.CleanSession + err = connack.Write(conn) if err != nil { - log.Error(err) + log.Error("send connack error, ", err) return } - connack := message.NewConnackMessage() - connack.SetReturnCode(message.ConnectionAccepted) - ack, _ := EncodeMessage(connack) - err1 := WriteBuffer(conn, ack) - if err1 != nil { - log.Error("send connack error, ", err1) - return - } - - willmsg := message.NewPublishMessage() - if connMsg.WillFlag() { - willmsg.SetQoS(connMsg.WillQos()) - willmsg.SetPayload(connMsg.WillMessage()) - willmsg.SetRetain(connMsg.WillRetain()) - willmsg.SetTopic(connMsg.WillTopic()) - willmsg.SetDup(false) + willmsg := packets.NewControlPacket(packets.Publish).(*packets.PublishPacket) + if msg.WillFlag { + willmsg.Qos = msg.WillQos + willmsg.TopicName = msg.WillTopic + willmsg.Retain = msg.WillRetain + willmsg.Payload = msg.WillMessage + willmsg.Dup = msg.Dup } else { willmsg = nil } info := info{ - clientID: connMsg.ClientId(), - username: connMsg.Username(), - password: connMsg.Password(), - keepalive: connMsg.KeepAlive(), + clientID: msg.ClientIdentifier, + username: msg.Username, + password: msg.Password, + keepalive: msg.Keepalive, willMsg: willmsg, } @@ -256,7 +260,7 @@ func (b *Broker) connectRouter(url, remoteID string) { } cid := GenUniqueId() info := info{ - clientID: []byte(cid), + clientID: cid, } c := &client{ typ: REMOTE, diff --git a/broker/client.go b/broker/client.go index 0436941..0230d5e 100644 --- a/broker/client.go +++ b/broker/client.go @@ -3,6 +3,7 @@ package broker import ( "errors" "hmq/lib/message" + "hmq/packets" "net" "strings" "sync" @@ -39,11 +40,11 @@ type subscription struct { } type info struct { - clientID []byte - username []byte + clientID string + username string password []byte keepalive uint16 - willMsg *message.PublishMessage + willMsg packets.ControlPacket localIP string remoteIP string } @@ -66,14 +67,15 @@ func (c *client) readLoop(msgPool *MessagePool) { } msg := &Message{} for { - buf, err := ReadPacket(nc) + packet, err := packets.ReadPacket(conn) + // buf, err := ReadPacket(nc) if err != nil { log.Error("read packet error: ", err) c.Close() return } msg.client = c - msg.msg = buf + msg.msg = packet msgPool.queue <- msg } msgPool.Reduce() @@ -85,45 +87,44 @@ func ProcessMessage(msg *Message) { if c == nil || buf == nil { return } - msgType := uint8(buf[0] & 0xF0 >> 4) - switch msgType { - case CONNACK: + switch m := msg.(type) { + case *packets.Connack: // log.Info("Recv conack message..........") c.ProcessConnAck(buf) - case CONNECT: + case *packets.Connect: // log.Info("Recv connect message..........") c.ProcessConnect(buf) - case PUBLISH: + case *packets.Publish: // log.Info("Recv publish message..........") c.ProcessPublish(buf) - case PUBACK: + case *packets.Puback: //log.Info("Recv publish ack message..........") c.ProcessPubAck(buf) - case PUBCOMP: - //log.Info("Recv publish ack message..........") - c.ProcessPubComp(buf) - case PUBREC: + case *packets.Pubrec:: //log.Info("Recv publish rec message..........") c.ProcessPubREC(buf) - case PUBREL: + case *packets.Pubrel: //log.Info("Recv publish rel message..........") c.ProcessPubREL(buf) - case SUBSCRIBE: + case *packets.Pubcomp: + //log.Info("Recv publish ack message..........") + c.ProcessPubComp(buf) + case *packets.Subscribe: // log.Info("Recv subscribe message.....") c.ProcessSubscribe(buf) - case SUBACK: + case *packets.Suback: // log.Info("Recv suback message.....") - case UNSUBSCRIBE: + case *packets.Unsubscribe: // log.Info("Recv unsubscribe message.....") c.ProcessUnSubscribe(buf) - case UNSUBACK: + case *packets.Unsuback: //log.Info("Recv unsuback message.....") - case PINGREQ: + case *packets.Pingreq: // log.Info("Recv PINGREQ message..........") c.ProcessPing(buf) - case PINGRESP: + case *packets.PingrespPacket: //log.Info("Recv PINGRESP message..........") - case DISCONNECT: + case *packets.Disconnect: // log.Info("Recv DISCONNECT message.......") c.Close() default: @@ -406,7 +407,7 @@ func (c *client) Close() { } } if c.info.willMsg != nil { - b.ProcessPublishMessage(c.info.willMsg) + // b.ProcessPublishMessage(c.info.willMsg) } } if c.conn != nil { diff --git a/broker/info.go b/broker/info.go index 83c8ef5..1981e0f 100644 --- a/broker/info.go +++ b/broker/info.go @@ -39,7 +39,7 @@ func (c *client) SendConnect() { clientID := c.info.clientID connMsg := message.NewConnectMessage() - connMsg.SetClientId(clientID) + connMsg.SetClientId([]byte(clientID)) connMsg.SetVersion(0x04) err := c.writeMessage(connMsg) if err != nil { diff --git a/broker/msgpool.go b/broker/msgpool.go index c672236..d07c27b 100644 --- a/broker/msgpool.go +++ b/broker/msgpool.go @@ -1,6 +1,9 @@ package broker -import "sync" +import ( + "hmq/packets" + "sync" +) const ( MaxUser = 1024 * 1024 @@ -11,7 +14,7 @@ const ( type Message struct { client *client - msg []byte + msg packets.ControlPacket } var ( diff --git a/packets/connack.go b/packets/connack.go new file mode 100644 index 0000000..a512ace --- /dev/null +++ b/packets/connack.go @@ -0,0 +1,51 @@ +package packets + +import ( + "bytes" + "fmt" + "io" +) + +//ConnackPacket is an internal representation of the fields of the +//Connack MQTT packet +type ConnackPacket struct { + FixedHeader + SessionPresent bool + ReturnCode byte +} + +func (ca *ConnackPacket) String() string { + str := fmt.Sprintf("%s", ca.FixedHeader) + str += " " + str += fmt.Sprintf("sessionpresent: %t returncode: %d", ca.SessionPresent, ca.ReturnCode) + return str +} + +func (ca *ConnackPacket) Write(w io.Writer) error { + var body bytes.Buffer + var err error + + body.WriteByte(boolToByte(ca.SessionPresent)) + body.WriteByte(ca.ReturnCode) + ca.FixedHeader.RemainingLength = 2 + packet := ca.FixedHeader.pack() + packet.Write(body.Bytes()) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (ca *ConnackPacket) Unpack(b io.Reader) error { + ca.SessionPresent = 1&decodeByte(b) > 0 + ca.ReturnCode = decodeByte(b) + + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (ca *ConnackPacket) Details() Details { + return Details{Qos: 0, MessageID: 0} +} diff --git a/packets/connect.go b/packets/connect.go new file mode 100644 index 0000000..3694464 --- /dev/null +++ b/packets/connect.go @@ -0,0 +1,126 @@ +package packets + +import ( + "bytes" + "fmt" + "io" +) + +//ConnectPacket is an internal representation of the fields of the +//Connect MQTT packet +type ConnectPacket struct { + FixedHeader + ProtocolName string + ProtocolVersion byte + CleanSession bool + WillFlag bool + WillQos byte + WillRetain bool + UsernameFlag bool + PasswordFlag bool + ReservedBit byte + Keepalive uint16 + + ClientIdentifier string + WillTopic string + WillMessage []byte + Username string + Password []byte +} + +func (c *ConnectPacket) String() string { + str := fmt.Sprintf("%s", c.FixedHeader) + str += " " + str += fmt.Sprintf("protocolversion: %d protocolname: %s cleansession: %t willflag: %t WillQos: %d WillRetain: %t Usernameflag: %t Passwordflag: %t keepalive: %d clientId: %s willtopic: %s willmessage: %s Username: %s Password: %s", c.ProtocolVersion, c.ProtocolName, c.CleanSession, c.WillFlag, c.WillQos, c.WillRetain, c.UsernameFlag, c.PasswordFlag, c.Keepalive, c.ClientIdentifier, c.WillTopic, c.WillMessage, c.Username, c.Password) + return str +} + +func (c *ConnectPacket) Write(w io.Writer) error { + var body bytes.Buffer + var err error + + body.Write(encodeString(c.ProtocolName)) + body.WriteByte(c.ProtocolVersion) + body.WriteByte(boolToByte(c.CleanSession)<<1 | boolToByte(c.WillFlag)<<2 | c.WillQos<<3 | boolToByte(c.WillRetain)<<5 | boolToByte(c.PasswordFlag)<<6 | boolToByte(c.UsernameFlag)<<7) + body.Write(encodeUint16(c.Keepalive)) + body.Write(encodeString(c.ClientIdentifier)) + if c.WillFlag { + body.Write(encodeString(c.WillTopic)) + body.Write(encodeBytes(c.WillMessage)) + } + if c.UsernameFlag { + body.Write(encodeString(c.Username)) + } + if c.PasswordFlag { + body.Write(encodeBytes(c.Password)) + } + c.FixedHeader.RemainingLength = body.Len() + packet := c.FixedHeader.pack() + packet.Write(body.Bytes()) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (c *ConnectPacket) Unpack(b io.Reader) error { + c.ProtocolName = decodeString(b) + c.ProtocolVersion = decodeByte(b) + options := decodeByte(b) + c.ReservedBit = 1 & options + c.CleanSession = 1&(options>>1) > 0 + c.WillFlag = 1&(options>>2) > 0 + c.WillQos = 3 & (options >> 3) + c.WillRetain = 1&(options>>5) > 0 + c.PasswordFlag = 1&(options>>6) > 0 + c.UsernameFlag = 1&(options>>7) > 0 + c.Keepalive = decodeUint16(b) + c.ClientIdentifier = decodeString(b) + if c.WillFlag { + c.WillTopic = decodeString(b) + c.WillMessage = decodeBytes(b) + } + if c.UsernameFlag { + c.Username = decodeString(b) + } + if c.PasswordFlag { + c.Password = decodeBytes(b) + } + + return nil +} + +//Validate performs validation of the fields of a Connect packet +func (c *ConnectPacket) Validate() byte { + if c.PasswordFlag && !c.UsernameFlag { + return ErrRefusedBadUsernameOrPassword + } + if c.ReservedBit != 0 { + //Bad reserved bit + return ErrProtocolViolation + } + if (c.ProtocolName == "MQIsdp" && c.ProtocolVersion != 3) || (c.ProtocolName == "MQTT" && c.ProtocolVersion != 4) { + //Mismatched or unsupported protocol version + return ErrRefusedBadProtocolVersion + } + if c.ProtocolName != "MQIsdp" && c.ProtocolName != "MQTT" { + //Bad protocol name + return ErrProtocolViolation + } + if len(c.ClientIdentifier) > 65535 || len(c.Username) > 65535 || len(c.Password) > 65535 { + //Bad size field + return ErrProtocolViolation + } + if len(c.ClientIdentifier) == 0 && !c.CleanSession { + //Bad client identifier + return ErrRefusedIDRejected + } + return Accepted +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (c *ConnectPacket) Details() Details { + return Details{Qos: 0, MessageID: 0} +} diff --git a/packets/disconnect.go b/packets/disconnect.go new file mode 100644 index 0000000..e5c1869 --- /dev/null +++ b/packets/disconnect.go @@ -0,0 +1,36 @@ +package packets + +import ( + "fmt" + "io" +) + +//DisconnectPacket is an internal representation of the fields of the +//Disconnect MQTT packet +type DisconnectPacket struct { + FixedHeader +} + +func (d *DisconnectPacket) String() string { + str := fmt.Sprintf("%s", d.FixedHeader) + return str +} + +func (d *DisconnectPacket) Write(w io.Writer) error { + packet := d.FixedHeader.pack() + _, err := packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (d *DisconnectPacket) Unpack(b io.Reader) error { + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (d *DisconnectPacket) Details() Details { + return Details{Qos: 0, MessageID: 0} +} diff --git a/packets/packets.go b/packets/packets.go new file mode 100644 index 0000000..cbc194a --- /dev/null +++ b/packets/packets.go @@ -0,0 +1,322 @@ +package packets + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" +) + +//ControlPacket defines the interface for structs intended to hold +//decoded MQTT packets, either from being read or before being +//written +type ControlPacket interface { + Write(io.Writer) error + Unpack(io.Reader) error + String() string + Details() Details +} + +//PacketNames maps the constants for each of the MQTT packet types +//to a string representation of their name. +var PacketNames = map[uint8]string{ + 1: "CONNECT", + 2: "CONNACK", + 3: "PUBLISH", + 4: "PUBACK", + 5: "PUBREC", + 6: "PUBREL", + 7: "PUBCOMP", + 8: "SUBSCRIBE", + 9: "SUBACK", + 10: "UNSUBSCRIBE", + 11: "UNSUBACK", + 12: "PINGREQ", + 13: "PINGRESP", + 14: "DISCONNECT", +} + +//Below are the constants assigned to each of the MQTT packet types +const ( + Connect = 1 + Connack = 2 + Publish = 3 + Puback = 4 + Pubrec = 5 + Pubrel = 6 + Pubcomp = 7 + Subscribe = 8 + Suback = 9 + Unsubscribe = 10 + Unsuback = 11 + Pingreq = 12 + Pingresp = 13 + Disconnect = 14 +) + +//Below are the const definitions for error codes returned by +//Connect() +const ( + Accepted = 0x00 + ErrRefusedBadProtocolVersion = 0x01 + ErrRefusedIDRejected = 0x02 + ErrRefusedServerUnavailable = 0x03 + ErrRefusedBadUsernameOrPassword = 0x04 + ErrRefusedNotAuthorised = 0x05 + ErrNetworkError = 0xFE + ErrProtocolViolation = 0xFF +) + +//ConnackReturnCodes is a map of the error codes constants for Connect() +//to a string representation of the error +var ConnackReturnCodes = map[uint8]string{ + 0: "Connection Accepted", + 1: "Connection Refused: Bad Protocol Version", + 2: "Connection Refused: Client Identifier Rejected", + 3: "Connection Refused: Server Unavailable", + 4: "Connection Refused: Username or Password in unknown format", + 5: "Connection Refused: Not Authorised", + 254: "Connection Error", + 255: "Connection Refused: Protocol Violation", +} + +//ConnErrors is a map of the errors codes constants for Connect() +//to a Go error +var ConnErrors = map[byte]error{ + Accepted: nil, + ErrRefusedBadProtocolVersion: errors.New("Unnacceptable protocol version"), + ErrRefusedIDRejected: errors.New("Identifier rejected"), + ErrRefusedServerUnavailable: errors.New("Server Unavailable"), + ErrRefusedBadUsernameOrPassword: errors.New("Bad user name or password"), + ErrRefusedNotAuthorised: errors.New("Not Authorized"), + ErrNetworkError: errors.New("Network Error"), + ErrProtocolViolation: errors.New("Protocol Violation"), +} + +//ReadPacket takes an instance of an io.Reader (such as net.Conn) and attempts +//to read an MQTT packet from the stream. It returns a ControlPacket +//representing the decoded MQTT packet and an error. One of these returns will +//always be nil, a nil ControlPacket indicating an error occurred. +func ReadPacket(r io.Reader) (cp ControlPacket, err error) { + var fh FixedHeader + b := make([]byte, 1) + + _, err = io.ReadFull(r, b) + if err != nil { + return nil, err + } + fh.unpack(b[0], r) + cp = NewControlPacketWithHeader(fh) + if cp == nil { + return nil, errors.New("Bad data from client") + } + packetBytes := make([]byte, fh.RemainingLength) + _, err = io.ReadFull(r, packetBytes) + if err != nil { + return nil, err + } + err = cp.Unpack(bytes.NewBuffer(packetBytes)) + return cp, err +} + +//NewControlPacket is used to create a new ControlPacket of the type specified +//by packetType, this is usually done by reference to the packet type constants +//defined in packets.go. The newly created ControlPacket is empty and a pointer +//is returned. +func NewControlPacket(packetType byte) (cp ControlPacket) { + switch packetType { + case Connect: + cp = &ConnectPacket{FixedHeader: FixedHeader{MessageType: Connect}} + case Connack: + cp = &ConnackPacket{FixedHeader: FixedHeader{MessageType: Connack}} + case Disconnect: + cp = &DisconnectPacket{FixedHeader: FixedHeader{MessageType: Disconnect}} + case Publish: + cp = &PublishPacket{FixedHeader: FixedHeader{MessageType: Publish}} + case Puback: + cp = &PubackPacket{FixedHeader: FixedHeader{MessageType: Puback}} + case Pubrec: + cp = &PubrecPacket{FixedHeader: FixedHeader{MessageType: Pubrec}} + case Pubrel: + cp = &PubrelPacket{FixedHeader: FixedHeader{MessageType: Pubrel, Qos: 1}} + case Pubcomp: + cp = &PubcompPacket{FixedHeader: FixedHeader{MessageType: Pubcomp}} + case Subscribe: + cp = &SubscribePacket{FixedHeader: FixedHeader{MessageType: Subscribe, Qos: 1}} + case Suback: + cp = &SubackPacket{FixedHeader: FixedHeader{MessageType: Suback}} + case Unsubscribe: + cp = &UnsubscribePacket{FixedHeader: FixedHeader{MessageType: Unsubscribe, Qos: 1}} + case Unsuback: + cp = &UnsubackPacket{FixedHeader: FixedHeader{MessageType: Unsuback}} + case Pingreq: + cp = &PingreqPacket{FixedHeader: FixedHeader{MessageType: Pingreq}} + case Pingresp: + cp = &PingrespPacket{FixedHeader: FixedHeader{MessageType: Pingresp}} + default: + return nil + } + return cp +} + +//NewControlPacketWithHeader is used to create a new ControlPacket of the type +//specified within the FixedHeader that is passed to the function. +//The newly created ControlPacket is empty and a pointer is returned. +func NewControlPacketWithHeader(fh FixedHeader) (cp ControlPacket) { + switch fh.MessageType { + case Connect: + cp = &ConnectPacket{FixedHeader: fh} + case Connack: + cp = &ConnackPacket{FixedHeader: fh} + case Disconnect: + cp = &DisconnectPacket{FixedHeader: fh} + case Publish: + cp = &PublishPacket{FixedHeader: fh} + case Puback: + cp = &PubackPacket{FixedHeader: fh} + case Pubrec: + cp = &PubrecPacket{FixedHeader: fh} + case Pubrel: + cp = &PubrelPacket{FixedHeader: fh} + case Pubcomp: + cp = &PubcompPacket{FixedHeader: fh} + case Subscribe: + cp = &SubscribePacket{FixedHeader: fh} + case Suback: + cp = &SubackPacket{FixedHeader: fh} + case Unsubscribe: + cp = &UnsubscribePacket{FixedHeader: fh} + case Unsuback: + cp = &UnsubackPacket{FixedHeader: fh} + case Pingreq: + cp = &PingreqPacket{FixedHeader: fh} + case Pingresp: + cp = &PingrespPacket{FixedHeader: fh} + default: + return nil + } + return cp +} + +//Details struct returned by the Details() function called on +//ControlPackets to present details of the Qos and MessageID +//of the ControlPacket +type Details struct { + Qos byte + MessageID uint16 +} + +//FixedHeader is a struct to hold the decoded information from +//the fixed header of an MQTT ControlPacket +type FixedHeader struct { + MessageType byte + Dup bool + Qos byte + Retain bool + RemainingLength int +} + +func (fh FixedHeader) String() string { + return fmt.Sprintf("%s: dup: %t qos: %d retain: %t rLength: %d", PacketNames[fh.MessageType], fh.Dup, fh.Qos, fh.Retain, fh.RemainingLength) +} + +func boolToByte(b bool) byte { + switch b { + case true: + return 1 + default: + return 0 + } +} + +func (fh *FixedHeader) pack() bytes.Buffer { + var header bytes.Buffer + header.WriteByte(fh.MessageType<<4 | boolToByte(fh.Dup)<<3 | fh.Qos<<1 | boolToByte(fh.Retain)) + header.Write(encodeLength(fh.RemainingLength)) + return header +} + +func (fh *FixedHeader) unpack(typeAndFlags byte, r io.Reader) { + fh.MessageType = typeAndFlags >> 4 + fh.Dup = (typeAndFlags>>3)&0x01 > 0 + fh.Qos = (typeAndFlags >> 1) & 0x03 + fh.Retain = typeAndFlags&0x01 > 0 + fh.RemainingLength = decodeLength(r) +} + +func decodeByte(b io.Reader) byte { + num := make([]byte, 1) + b.Read(num) + return num[0] +} + +func decodeUint16(b io.Reader) uint16 { + num := make([]byte, 2) + b.Read(num) + return binary.BigEndian.Uint16(num) +} + +func encodeUint16(num uint16) []byte { + bytes := make([]byte, 2) + binary.BigEndian.PutUint16(bytes, num) + return bytes +} + +func encodeString(field string) []byte { + fieldLength := make([]byte, 2) + binary.BigEndian.PutUint16(fieldLength, uint16(len(field))) + return append(fieldLength, []byte(field)...) +} + +func decodeString(b io.Reader) string { + fieldLength := decodeUint16(b) + field := make([]byte, fieldLength) + b.Read(field) + return string(field) +} + +func decodeBytes(b io.Reader) []byte { + fieldLength := decodeUint16(b) + field := make([]byte, fieldLength) + b.Read(field) + return field +} + +func encodeBytes(field []byte) []byte { + fieldLength := make([]byte, 2) + binary.BigEndian.PutUint16(fieldLength, uint16(len(field))) + return append(fieldLength, field...) +} + +func encodeLength(length int) []byte { + var encLength []byte + for { + digit := byte(length % 128) + length /= 128 + if length > 0 { + digit |= 0x80 + } + encLength = append(encLength, digit) + if length == 0 { + break + } + } + return encLength +} + +func decodeLength(r io.Reader) int { + var rLength uint32 + var multiplier uint32 + b := make([]byte, 1) + for multiplier < 27 { //fix: Infinite '(digit & 128) == 1' will cause the dead loop + io.ReadFull(r, b) + digit := b[0] + rLength |= uint32(digit&127) << multiplier + if (digit & 128) == 0 { + break + } + multiplier += 7 + } + return int(rLength) +} diff --git a/packets/packets_test.go b/packets/packets_test.go new file mode 100644 index 0000000..51d887d --- /dev/null +++ b/packets/packets_test.go @@ -0,0 +1,159 @@ +package packets + +import ( + "bytes" + "testing" +) + +func TestPacketNames(t *testing.T) { + if PacketNames[1] != "CONNECT" { + t.Errorf("PacketNames[1] is %s, should be %s", PacketNames[1], "CONNECT") + } + if PacketNames[2] != "CONNACK" { + t.Errorf("PacketNames[2] is %s, should be %s", PacketNames[2], "CONNACK") + } + if PacketNames[3] != "PUBLISH" { + t.Errorf("PacketNames[3] is %s, should be %s", PacketNames[3], "PUBLISH") + } + if PacketNames[4] != "PUBACK" { + t.Errorf("PacketNames[4] is %s, should be %s", PacketNames[4], "PUBACK") + } + if PacketNames[5] != "PUBREC" { + t.Errorf("PacketNames[5] is %s, should be %s", PacketNames[5], "PUBREC") + } + if PacketNames[6] != "PUBREL" { + t.Errorf("PacketNames[6] is %s, should be %s", PacketNames[6], "PUBREL") + } + if PacketNames[7] != "PUBCOMP" { + t.Errorf("PacketNames[7] is %s, should be %s", PacketNames[7], "PUBCOMP") + } + if PacketNames[8] != "SUBSCRIBE" { + t.Errorf("PacketNames[8] is %s, should be %s", PacketNames[8], "SUBSCRIBE") + } + if PacketNames[9] != "SUBACK" { + t.Errorf("PacketNames[9] is %s, should be %s", PacketNames[9], "SUBACK") + } + if PacketNames[10] != "UNSUBSCRIBE" { + t.Errorf("PacketNames[10] is %s, should be %s", PacketNames[10], "UNSUBSCRIBE") + } + if PacketNames[11] != "UNSUBACK" { + t.Errorf("PacketNames[11] is %s, should be %s", PacketNames[11], "UNSUBACK") + } + if PacketNames[12] != "PINGREQ" { + t.Errorf("PacketNames[12] is %s, should be %s", PacketNames[12], "PINGREQ") + } + if PacketNames[13] != "PINGRESP" { + t.Errorf("PacketNames[13] is %s, should be %s", PacketNames[13], "PINGRESP") + } + if PacketNames[14] != "DISCONNECT" { + t.Errorf("PacketNames[14] is %s, should be %s", PacketNames[14], "DISCONNECT") + } +} + +func TestPacketConsts(t *testing.T) { + if Connect != 1 { + t.Errorf("Const for Connect is %d, should be %d", Connect, 1) + } + if Connack != 2 { + t.Errorf("Const for Connack is %d, should be %d", Connack, 2) + } + if Publish != 3 { + t.Errorf("Const for Publish is %d, should be %d", Publish, 3) + } + if Puback != 4 { + t.Errorf("Const for Puback is %d, should be %d", Puback, 4) + } + if Pubrec != 5 { + t.Errorf("Const for Pubrec is %d, should be %d", Pubrec, 5) + } + if Pubrel != 6 { + t.Errorf("Const for Pubrel is %d, should be %d", Pubrel, 6) + } + if Pubcomp != 7 { + t.Errorf("Const for Pubcomp is %d, should be %d", Pubcomp, 7) + } + if Subscribe != 8 { + t.Errorf("Const for Subscribe is %d, should be %d", Subscribe, 8) + } + if Suback != 9 { + t.Errorf("Const for Suback is %d, should be %d", Suback, 9) + } + if Unsubscribe != 10 { + t.Errorf("Const for Unsubscribe is %d, should be %d", Unsubscribe, 10) + } + if Unsuback != 11 { + t.Errorf("Const for Unsuback is %d, should be %d", Unsuback, 11) + } + if Pingreq != 12 { + t.Errorf("Const for Pingreq is %d, should be %d", Pingreq, 12) + } + if Pingresp != 13 { + t.Errorf("Const for Pingresp is %d, should be %d", Pingresp, 13) + } + if Disconnect != 14 { + t.Errorf("Const for Disconnect is %d, should be %d", Disconnect, 14) + } +} + +func TestConnackConsts(t *testing.T) { + if Accepted != 0x00 { + t.Errorf("Const for Accepted is %d, should be %d", Accepted, 0) + } + if ErrRefusedBadProtocolVersion != 0x01 { + t.Errorf("Const for RefusedBadProtocolVersion is %d, should be %d", ErrRefusedBadProtocolVersion, 1) + } + if ErrRefusedIDRejected != 0x02 { + t.Errorf("Const for RefusedIDRejected is %d, should be %d", ErrRefusedIDRejected, 2) + } + if ErrRefusedServerUnavailable != 0x03 { + t.Errorf("Const for RefusedServerUnavailable is %d, should be %d", ErrRefusedServerUnavailable, 3) + } + if ErrRefusedBadUsernameOrPassword != 0x04 { + t.Errorf("Const for RefusedBadUsernameOrPassword is %d, should be %d", ErrRefusedBadUsernameOrPassword, 4) + } + if ErrRefusedNotAuthorised != 0x05 { + t.Errorf("Const for RefusedNotAuthorised is %d, should be %d", ErrRefusedNotAuthorised, 5) + } +} + +func TestConnectPacket(t *testing.T) { + connectPacketBytes := bytes.NewBuffer([]byte{16, 52, 0, 4, 77, 81, 84, 84, 4, 204, 0, 0, 0, 0, 0, 4, 116, 101, 115, 116, 0, 12, 84, 101, 115, 116, 32, 80, 97, 121, 108, 111, 97, 100, 0, 8, 116, 101, 115, 116, 117, 115, 101, 114, 0, 8, 116, 101, 115, 116, 112, 97, 115, 115}) + packet, err := ReadPacket(connectPacketBytes) + if err != nil { + t.Fatalf("Error reading packet: %s", err.Error()) + } + cp := packet.(*ConnectPacket) + if cp.ProtocolName != "MQTT" { + t.Errorf("Connect Packet ProtocolName is %s, should be %s", cp.ProtocolName, "MQTT") + } + if cp.ProtocolVersion != 4 { + t.Errorf("Connect Packet ProtocolVersion is %d, should be %d", cp.ProtocolVersion, 4) + } + if cp.UsernameFlag != true { + t.Errorf("Connect Packet UsernameFlag is %t, should be %t", cp.UsernameFlag, true) + } + if cp.Username != "testuser" { + t.Errorf("Connect Packet Username is %s, should be %s", cp.Username, "testuser") + } + if cp.PasswordFlag != true { + t.Errorf("Connect Packet PasswordFlag is %t, should be %t", cp.PasswordFlag, true) + } + if string(cp.Password) != "testpass" { + t.Errorf("Connect Packet Password is %s, should be %s", string(cp.Password), "testpass") + } + if cp.WillFlag != true { + t.Errorf("Connect Packet WillFlag is %t, should be %t", cp.WillFlag, true) + } + if cp.WillTopic != "test" { + t.Errorf("Connect Packet WillTopic is %s, should be %s", cp.WillTopic, "test") + } + if cp.WillQos != 1 { + t.Errorf("Connect Packet WillQos is %d, should be %d", cp.WillQos, 1) + } + if cp.WillRetain != false { + t.Errorf("Connect Packet WillRetain is %t, should be %t", cp.WillRetain, false) + } + if string(cp.WillMessage) != "Test Payload" { + t.Errorf("Connect Packet WillMessage is %s, should be %s", string(cp.WillMessage), "Test Payload") + } +} diff --git a/packets/pingreq.go b/packets/pingreq.go new file mode 100644 index 0000000..5c3e88f --- /dev/null +++ b/packets/pingreq.go @@ -0,0 +1,36 @@ +package packets + +import ( + "fmt" + "io" +) + +//PingreqPacket is an internal representation of the fields of the +//Pingreq MQTT packet +type PingreqPacket struct { + FixedHeader +} + +func (pr *PingreqPacket) String() string { + str := fmt.Sprintf("%s", pr.FixedHeader) + return str +} + +func (pr *PingreqPacket) Write(w io.Writer) error { + packet := pr.FixedHeader.pack() + _, err := packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (pr *PingreqPacket) Unpack(b io.Reader) error { + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (pr *PingreqPacket) Details() Details { + return Details{Qos: 0, MessageID: 0} +} diff --git a/packets/pingresp.go b/packets/pingresp.go new file mode 100644 index 0000000..39ebc00 --- /dev/null +++ b/packets/pingresp.go @@ -0,0 +1,36 @@ +package packets + +import ( + "fmt" + "io" +) + +//PingrespPacket is an internal representation of the fields of the +//Pingresp MQTT packet +type PingrespPacket struct { + FixedHeader +} + +func (pr *PingrespPacket) String() string { + str := fmt.Sprintf("%s", pr.FixedHeader) + return str +} + +func (pr *PingrespPacket) Write(w io.Writer) error { + packet := pr.FixedHeader.pack() + _, err := packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (pr *PingrespPacket) Unpack(b io.Reader) error { + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (pr *PingrespPacket) Details() Details { + return Details{Qos: 0, MessageID: 0} +} diff --git a/packets/puback.go b/packets/puback.go new file mode 100644 index 0000000..e30402c --- /dev/null +++ b/packets/puback.go @@ -0,0 +1,44 @@ +package packets + +import ( + "fmt" + "io" +) + +//PubackPacket is an internal representation of the fields of the +//Puback MQTT packet +type PubackPacket struct { + FixedHeader + MessageID uint16 +} + +func (pa *PubackPacket) String() string { + str := fmt.Sprintf("%s", pa.FixedHeader) + str += " " + str += fmt.Sprintf("MessageID: %d", pa.MessageID) + return str +} + +func (pa *PubackPacket) Write(w io.Writer) error { + var err error + pa.FixedHeader.RemainingLength = 2 + packet := pa.FixedHeader.pack() + packet.Write(encodeUint16(pa.MessageID)) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (pa *PubackPacket) Unpack(b io.Reader) error { + pa.MessageID = decodeUint16(b) + + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (pa *PubackPacket) Details() Details { + return Details{Qos: pa.Qos, MessageID: pa.MessageID} +} diff --git a/packets/pubcomp.go b/packets/pubcomp.go new file mode 100644 index 0000000..fb994ae --- /dev/null +++ b/packets/pubcomp.go @@ -0,0 +1,44 @@ +package packets + +import ( + "fmt" + "io" +) + +//PubcompPacket is an internal representation of the fields of the +//Pubcomp MQTT packet +type PubcompPacket struct { + FixedHeader + MessageID uint16 +} + +func (pc *PubcompPacket) String() string { + str := fmt.Sprintf("%s", pc.FixedHeader) + str += " " + str += fmt.Sprintf("MessageID: %d", pc.MessageID) + return str +} + +func (pc *PubcompPacket) Write(w io.Writer) error { + var err error + pc.FixedHeader.RemainingLength = 2 + packet := pc.FixedHeader.pack() + packet.Write(encodeUint16(pc.MessageID)) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (pc *PubcompPacket) Unpack(b io.Reader) error { + pc.MessageID = decodeUint16(b) + + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (pc *PubcompPacket) Details() Details { + return Details{Qos: pc.Qos, MessageID: pc.MessageID} +} diff --git a/packets/publish.go b/packets/publish.go new file mode 100644 index 0000000..b660ef4 --- /dev/null +++ b/packets/publish.go @@ -0,0 +1,80 @@ +package packets + +import ( + "bytes" + "fmt" + "io" +) + +//PublishPacket is an internal representation of the fields of the +//Publish MQTT packet +type PublishPacket struct { + FixedHeader + TopicName string + MessageID uint16 + Payload []byte +} + +func (p *PublishPacket) String() string { + str := fmt.Sprintf("%s", p.FixedHeader) + str += " " + str += fmt.Sprintf("topicName: %s MessageID: %d", p.TopicName, p.MessageID) + str += " " + str += fmt.Sprintf("payload: %s", string(p.Payload)) + return str +} + +func (p *PublishPacket) Write(w io.Writer) error { + var body bytes.Buffer + var err error + + body.Write(encodeString(p.TopicName)) + if p.Qos > 0 { + body.Write(encodeUint16(p.MessageID)) + } + p.FixedHeader.RemainingLength = body.Len() + len(p.Payload) + packet := p.FixedHeader.pack() + packet.Write(body.Bytes()) + packet.Write(p.Payload) + _, err = w.Write(packet.Bytes()) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (p *PublishPacket) Unpack(b io.Reader) error { + var payloadLength = p.FixedHeader.RemainingLength + p.TopicName = decodeString(b) + if p.Qos > 0 { + p.MessageID = decodeUint16(b) + payloadLength -= len(p.TopicName) + 4 + } else { + payloadLength -= len(p.TopicName) + 2 + } + if payloadLength < 0 { + return fmt.Errorf("Error upacking publish, payload length < 0") + } + p.Payload = make([]byte, payloadLength) + _, err := b.Read(p.Payload) + + return err +} + +//Copy creates a new PublishPacket with the same topic and payload +//but an empty fixed header, useful for when you want to deliver +//a message with different properties such as Qos but the same +//content +func (p *PublishPacket) Copy() *PublishPacket { + newP := NewControlPacket(Publish).(*PublishPacket) + newP.TopicName = p.TopicName + newP.Payload = p.Payload + + return newP +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (p *PublishPacket) Details() Details { + return Details{Qos: p.Qos, MessageID: p.MessageID} +} diff --git a/packets/pubrec.go b/packets/pubrec.go new file mode 100644 index 0000000..9874e64 --- /dev/null +++ b/packets/pubrec.go @@ -0,0 +1,44 @@ +package packets + +import ( + "fmt" + "io" +) + +//PubrecPacket is an internal representation of the fields of the +//Pubrec MQTT packet +type PubrecPacket struct { + FixedHeader + MessageID uint16 +} + +func (pr *PubrecPacket) String() string { + str := fmt.Sprintf("%s", pr.FixedHeader) + str += " " + str += fmt.Sprintf("MessageID: %d", pr.MessageID) + return str +} + +func (pr *PubrecPacket) Write(w io.Writer) error { + var err error + pr.FixedHeader.RemainingLength = 2 + packet := pr.FixedHeader.pack() + packet.Write(encodeUint16(pr.MessageID)) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (pr *PubrecPacket) Unpack(b io.Reader) error { + pr.MessageID = decodeUint16(b) + + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (pr *PubrecPacket) Details() Details { + return Details{Qos: pr.Qos, MessageID: pr.MessageID} +} diff --git a/packets/pubrel.go b/packets/pubrel.go new file mode 100644 index 0000000..a7ecce7 --- /dev/null +++ b/packets/pubrel.go @@ -0,0 +1,44 @@ +package packets + +import ( + "fmt" + "io" +) + +//PubrelPacket is an internal representation of the fields of the +//Pubrel MQTT packet +type PubrelPacket struct { + FixedHeader + MessageID uint16 +} + +func (pr *PubrelPacket) String() string { + str := fmt.Sprintf("%s", pr.FixedHeader) + str += " " + str += fmt.Sprintf("MessageID: %d", pr.MessageID) + return str +} + +func (pr *PubrelPacket) Write(w io.Writer) error { + var err error + pr.FixedHeader.RemainingLength = 2 + packet := pr.FixedHeader.pack() + packet.Write(encodeUint16(pr.MessageID)) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (pr *PubrelPacket) Unpack(b io.Reader) error { + pr.MessageID = decodeUint16(b) + + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (pr *PubrelPacket) Details() Details { + return Details{Qos: pr.Qos, MessageID: pr.MessageID} +} diff --git a/packets/suback.go b/packets/suback.go new file mode 100644 index 0000000..557a7db --- /dev/null +++ b/packets/suback.go @@ -0,0 +1,52 @@ +package packets + +import ( + "bytes" + "fmt" + "io" +) + +//SubackPacket is an internal representation of the fields of the +//Suback MQTT packet +type SubackPacket struct { + FixedHeader + MessageID uint16 + ReturnCodes []byte +} + +func (sa *SubackPacket) String() string { + str := fmt.Sprintf("%s", sa.FixedHeader) + str += " " + str += fmt.Sprintf("MessageID: %d", sa.MessageID) + return str +} + +func (sa *SubackPacket) Write(w io.Writer) error { + var body bytes.Buffer + var err error + body.Write(encodeUint16(sa.MessageID)) + body.Write(sa.ReturnCodes) + sa.FixedHeader.RemainingLength = body.Len() + packet := sa.FixedHeader.pack() + packet.Write(body.Bytes()) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (sa *SubackPacket) Unpack(b io.Reader) error { + var qosBuffer bytes.Buffer + sa.MessageID = decodeUint16(b) + qosBuffer.ReadFrom(b) + sa.ReturnCodes = qosBuffer.Bytes() + + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (sa *SubackPacket) Details() Details { + return Details{Qos: 0, MessageID: sa.MessageID} +} diff --git a/packets/subscribe.go b/packets/subscribe.go new file mode 100644 index 0000000..c418ef0 --- /dev/null +++ b/packets/subscribe.go @@ -0,0 +1,62 @@ +package packets + +import ( + "bytes" + "fmt" + "io" +) + +//SubscribePacket is an internal representation of the fields of the +//Subscribe MQTT packet +type SubscribePacket struct { + FixedHeader + MessageID uint16 + Topics []string + Qoss []byte +} + +func (s *SubscribePacket) String() string { + str := fmt.Sprintf("%s", s.FixedHeader) + str += " " + str += fmt.Sprintf("MessageID: %d topics: %s", s.MessageID, s.Topics) + return str +} + +func (s *SubscribePacket) Write(w io.Writer) error { + var body bytes.Buffer + var err error + + body.Write(encodeUint16(s.MessageID)) + for i, topic := range s.Topics { + body.Write(encodeString(topic)) + body.WriteByte(s.Qoss[i]) + } + s.FixedHeader.RemainingLength = body.Len() + packet := s.FixedHeader.pack() + packet.Write(body.Bytes()) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (s *SubscribePacket) Unpack(b io.Reader) error { + s.MessageID = decodeUint16(b) + payloadLength := s.FixedHeader.RemainingLength - 2 + for payloadLength > 0 { + topic := decodeString(b) + s.Topics = append(s.Topics, topic) + qos := decodeByte(b) + s.Qoss = append(s.Qoss, qos) + payloadLength -= 2 + len(topic) + 1 //2 bytes of string length, plus string, plus 1 byte for Qos + } + + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (s *SubscribePacket) Details() Details { + return Details{Qos: 1, MessageID: s.MessageID} +} diff --git a/packets/unsuback.go b/packets/unsuback.go new file mode 100644 index 0000000..b3b91ce --- /dev/null +++ b/packets/unsuback.go @@ -0,0 +1,44 @@ +package packets + +import ( + "fmt" + "io" +) + +//UnsubackPacket is an internal representation of the fields of the +//Unsuback MQTT packet +type UnsubackPacket struct { + FixedHeader + MessageID uint16 +} + +func (ua *UnsubackPacket) String() string { + str := fmt.Sprintf("%s", ua.FixedHeader) + str += " " + str += fmt.Sprintf("MessageID: %d", ua.MessageID) + return str +} + +func (ua *UnsubackPacket) Write(w io.Writer) error { + var err error + ua.FixedHeader.RemainingLength = 2 + packet := ua.FixedHeader.pack() + packet.Write(encodeUint16(ua.MessageID)) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (ua *UnsubackPacket) Unpack(b io.Reader) error { + ua.MessageID = decodeUint16(b) + + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (ua *UnsubackPacket) Details() Details { + return Details{Qos: 0, MessageID: ua.MessageID} +} diff --git a/packets/unsubscribe.go b/packets/unsubscribe.go new file mode 100644 index 0000000..dc6a89e --- /dev/null +++ b/packets/unsubscribe.go @@ -0,0 +1,55 @@ +package packets + +import ( + "bytes" + "fmt" + "io" +) + +//UnsubscribePacket is an internal representation of the fields of the +//Unsubscribe MQTT packet +type UnsubscribePacket struct { + FixedHeader + MessageID uint16 + Topics []string +} + +func (u *UnsubscribePacket) String() string { + str := fmt.Sprintf("%s", u.FixedHeader) + str += " " + str += fmt.Sprintf("MessageID: %d", u.MessageID) + return str +} + +func (u *UnsubscribePacket) Write(w io.Writer) error { + var body bytes.Buffer + var err error + body.Write(encodeUint16(u.MessageID)) + for _, topic := range u.Topics { + body.Write(encodeString(topic)) + } + u.FixedHeader.RemainingLength = body.Len() + packet := u.FixedHeader.pack() + packet.Write(body.Bytes()) + _, err = packet.WriteTo(w) + + return err +} + +//Unpack decodes the details of a ControlPacket after the fixed +//header has been read +func (u *UnsubscribePacket) Unpack(b io.Reader) error { + u.MessageID = decodeUint16(b) + var topic string + for topic = decodeString(b); topic != ""; topic = decodeString(b) { + u.Topics = append(u.Topics, topic) + } + + return nil +} + +//Details returns a Details struct containing the Qos and +//MessageID of this ControlPacket +func (u *UnsubscribePacket) Details() Details { + return Details{Qos: 1, MessageID: u.MessageID} +}