Skip to content

Commit

Permalink
handler.go: Start using cookies server-side too.
Browse files Browse the repository at this point in the history
  • Loading branch information
Sandertv committed May 4, 2024
1 parent 08602c9 commit 7a99b42
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 39 deletions.
4 changes: 2 additions & 2 deletions dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"fmt"
"github.com/sandertv/go-raknet/internal"
"log/slog"
"math/rand"
"math/rand/v2"
"net"
"sync/atomic"
"time"
Expand Down Expand Up @@ -179,7 +179,7 @@ func (dialer Dialer) dial(ctx context.Context, address string) (net.Conn, error)
}

// dialerID is a counter used to produce an ID for the client.
var dialerID = rand.Int63()
var dialerID = rand.Int64()

// Dial attempts to dial a RakNet connection to the address passed. The address
// may be either an IP address or a hostname, combined with a port that is
Expand Down
75 changes: 48 additions & 27 deletions handler.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
package raknet

import (
"encoding/binary"
"errors"
"fmt"
"github.com/sandertv/go-raknet/internal/message"
"hash/crc32"
"net"
"time"
)
Expand All @@ -14,7 +16,10 @@ type connectionHandler interface {
close(conn *Conn)
}

type listenerConnectionHandler struct{ l *Listener }
type listenerConnectionHandler struct {
l *Listener
cookieSalt uint32
}

var (
errUnexpectedCRA = errors.New("unexpected CONNECTION_REQUEST_ACCEPTED packet")
Expand All @@ -29,6 +34,20 @@ func (h listenerConnectionHandler) close(conn *Conn) {
h.l.connections.Delete(resolve(conn.raddr))
}

// cookie calculates a cookie for the net.Addr passed. It is calculated as a
// hash of the random cookie salt and the address.
func (h listenerConnectionHandler) cookie(addr net.Addr) uint32 {
udp, _ := addr.(*net.UDPAddr)
b := make([]byte, 6, 10)
binary.LittleEndian.PutUint32(b, h.cookieSalt)
binary.LittleEndian.PutUint16(b, uint16(udp.Port))
b = append(b, udp.IP...)
// CRC32 isn't cryptographically secure, but we don't really need that here.
// A new salt is calculated every time a Listener is created and we don't
// have any data that needs to protected. We just need a fast hash.
return crc32.ChecksumIEEE(b)
}

func (h listenerConnectionHandler) handleUnconnected(b []byte, addr net.Addr) error {
switch b[0] {
case message.IDUnconnectedPing, message.IDUnconnectedPingOpenConnections:
Expand All @@ -47,29 +66,6 @@ func (h listenerConnectionHandler) handleUnconnected(b []byte, addr net.Addr) er
return fmt.Errorf("unknown packet received (len=%v): %x", len(b), b)
}

func (h listenerConnectionHandler) handle(conn *Conn, b []byte) (handled bool, err error) {
switch b[0] {
case message.IDConnectionRequest:
return true, h.handleConnectionRequest(conn, b[1:])
case message.IDConnectionRequestAccepted:
return true, errUnexpectedCRA
case message.IDNewIncomingConnection:
return true, h.handleNewIncomingConnection(conn)
case message.IDConnectedPing:
return true, handleConnectedPing(conn, b[1:])
case message.IDConnectedPong:
return true, handleConnectedPong(b[1:])
case message.IDDisconnectNotification:
conn.closeImmediately()
return true, nil
case message.IDDetectLostConnections:
// Let the other end know the connection is still alive.
return true, conn.send(&message.ConnectedPing{PingTime: timestamp()})
default:
return false, nil
}
}

// handleUnconnectedPing handles an unconnected ping packet stored in buffer b,
// coming from an address.
func (h listenerConnectionHandler) handleUnconnectedPing(b []byte, addr net.Addr) error {
Expand Down Expand Up @@ -97,18 +93,21 @@ func (h listenerConnectionHandler) handleOpenConnectionRequest1(b []byte, addr n
return fmt.Errorf("handle OPEN_CONNECTION_REQUEST_1: incompatible protocol version %v (listener protocol = %v)", pk.ClientProtocol, protocolVersion)
}

data, _ := (&message.OpenConnectionReply1{ServerGUID: h.l.id, ServerHasSecurity: false, MTU: mtuSize}).MarshalBinary()
data, _ := (&message.OpenConnectionReply1{ServerGUID: h.l.id, Cookie: h.cookie(addr), ServerHasSecurity: true, MTU: mtuSize}).MarshalBinary()
_, err := h.l.conn.WriteTo(data, addr)
return err
}

// handleOpenConnectionRequest2 handles an open connection request 2 packet
// stored in buffer b, coming from an address.
func (h listenerConnectionHandler) handleOpenConnectionRequest2(b []byte, addr net.Addr) error {
pk := &message.OpenConnectionRequest2{}
pk := &message.OpenConnectionRequest2{ServerHasSecurity: true}
if err := pk.UnmarshalBinary(b); err != nil {
return fmt.Errorf("read OPEN_CONNECTION_REQUEST_2: %w", err)
}
if expected := h.cookie(addr); pk.Cookie != expected {
return fmt.Errorf("handle OPEN_CONNECTION_REQUEST_2: invalid cookie '%x', expected '%x'", pk.Cookie, expected)
}
mtuSize := min(pk.MTU, maxMTUSize)

data, _ := (&message.OpenConnectionReply2{ServerGUID: h.l.id, ClientAddress: resolve(addr), MTU: mtuSize}).MarshalBinary()
Expand All @@ -135,10 +134,32 @@ func (h listenerConnectionHandler) handleOpenConnectionRequest2(b []byte, addr n
_ = conn.Close()
}
}()

return nil
}

func (h listenerConnectionHandler) handle(conn *Conn, b []byte) (handled bool, err error) {
switch b[0] {
case message.IDConnectionRequest:
return true, h.handleConnectionRequest(conn, b[1:])
case message.IDConnectionRequestAccepted:
return true, errUnexpectedCRA
case message.IDNewIncomingConnection:
return true, h.handleNewIncomingConnection(conn)
case message.IDConnectedPing:
return true, handleConnectedPing(conn, b[1:])
case message.IDConnectedPong:
return true, handleConnectedPong(b[1:])
case message.IDDisconnectNotification:
conn.closeImmediately()
return true, nil
case message.IDDetectLostConnections:
// Let the other end know the connection is still alive.
return true, conn.send(&message.ConnectedPing{PingTime: timestamp()})
default:
return false, nil
}
}

// handleConnectionRequest handles a connection request packet inside of buffer
// b. An error is returned if the packet was invalid.
func (h listenerConnectionHandler) handleConnectionRequest(conn *Conn, b []byte) error {
Expand Down
25 changes: 19 additions & 6 deletions internal/message/open_connection_request_2.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,13 @@ import (
)

type OpenConnectionRequest2 struct {
ServerAddress netip.AddrPort
MTU uint16
ClientGUID int64
ServerAddress netip.AddrPort
MTU uint16
ClientGUID int64
// ServerHasSecurity specifies if the server has security enabled (and thus
// a cookie must be sent in this packet). This field is NOT written in this
// packet, so it must be set appropriately even when calling
// UnmarshalBinary.
ServerHasSecurity bool
Cookie uint32
}
Expand All @@ -34,12 +38,21 @@ func (pk *OpenConnectionRequest2) MarshalBinary() (data []byte, err error) {
}

func (pk *OpenConnectionRequest2) UnmarshalBinary(data []byte) error {
if len(data) < 16 || len(data) < 26+addrSize(data[16:]) {
cookieOffset := 0
if pk.ServerHasSecurity {
cookieOffset = 5
}
if len(data) < 16+cookieOffset || len(data) < 26+cookieOffset+addrSize(data[16+cookieOffset:]) {
return io.ErrUnexpectedEOF
}
// Magic: 16 bytes.
offset := addrSize(data[16:])
pk.ServerAddress, _ = addr(data[16:])
if pk.ServerHasSecurity {
pk.Cookie = binary.BigEndian.Uint32(data[16:])
}
offset := cookieOffset
var n int
pk.ServerAddress, n = addr(data[16+offset:])
offset += n
pk.MTU = binary.BigEndian.Uint16(data[16+offset:])
pk.ClientGUID = int64(binary.BigEndian.Uint64(data[18+offset:]))
return nil
Expand Down
8 changes: 4 additions & 4 deletions listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"fmt"
"log/slog"
"math"
"math/rand"
"math/rand/v2"
"net"
"sync"
"sync/atomic"
Expand All @@ -30,7 +30,7 @@ type ListenConfig struct {
// methods as those implemented by the TCPListener in the net package. Listener
// implements the net.Listener interface.
type Listener struct {
h listenerConnectionHandler
h *listenerConnectionHandler

once sync.Once
closed chan struct{}
Expand Down Expand Up @@ -58,7 +58,7 @@ type Listener struct {
}

// listenerID holds the next ID to use for a Listener.
var listenerID = rand.Int63()
var listenerID = rand.Int64()

// Listen listens on the address passed and returns a listener that may be used
// to accept connections. If not successful, an error is returned. The address
Expand All @@ -84,7 +84,7 @@ func (l ListenConfig) Listen(address string) (*Listener, error) {
log: l.ErrorLog,
id: atomic.AddInt64(&listenerID, 1),
}
listener.h.l = listener
listener.h = &listenerConnectionHandler{l: listener, cookieSalt: rand.Uint32()}
if l.ErrorLog == nil {
listener.log = slog.Default()
}
Expand Down

0 comments on commit 7a99b42

Please sign in to comment.