Skip to content

Commit 806601a

Browse files
authored
query: allow multiple session IDs per connection (#2)
1 parent d86be25 commit 806601a

File tree

2 files changed

+10
-9
lines changed

2 files changed

+10
-9
lines changed

src/query/packets.go

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,18 @@ import (
55
"errors"
66
"fmt"
77
"io"
8+
"net"
89
"main/src/util"
910
"math/rand"
1011
"strconv"
1112
"strings"
1213
)
1314

14-
func writeHandshakePacket(w io.Writer, sessionID int32) error {
15+
func writeHandshakePacket(w io.Writer, addr net.Addr, sessionID int32) error {
1516
challengeToken := strconv.FormatInt(int64(rand.Int31()), 10)
1617

1718
sessionsMutex.Lock()
18-
sessions[sessionID] = challengeToken
19+
sessions[addr.String()] = challengeToken
1920
sessionsMutex.Unlock()
2021

2122
// Type - byte
@@ -36,13 +37,13 @@ func writeHandshakePacket(w io.Writer, sessionID int32) error {
3637
return nil
3738
}
3839

39-
func readRequestPacket(r io.Reader, w io.Writer, sessionID int32) (bool, error) {
40+
func readRequestPacket(r io.Reader, w io.Writer, addr net.Addr, sessionID int32) (bool, error) {
4041
sessionsMutex.Lock()
4142

4243
defer sessionsMutex.Unlock()
4344

44-
if _, ok := sessions[sessionID]; !ok {
45-
return false, fmt.Errorf("query: invalid or expired session ID: %X", sessionID)
45+
if _, ok := sessions[addr.String()]; !ok {
46+
return false, fmt.Errorf("query: no currently active challenges for %s", addr.String())
4647
}
4748

4849
// Challenge Token - int32
@@ -53,7 +54,7 @@ func readRequestPacket(r io.Reader, w io.Writer, sessionID int32) (bool, error)
5354
return false, err
5455
}
5556

56-
if sessions[sessionID] != strconv.FormatInt(int64(challengeToken), 10) {
57+
if sessions[addr.String()] != strconv.FormatInt(int64(challengeToken), 10) {
5758
return false, fmt.Errorf("query: received challenge token did not match stored")
5859
}
5960
}

src/query/socket.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ import (
1313
var (
1414
socket net.PacketConn = nil
1515
conf *config.Config = nil
16-
sessions map[int32]string = make(map[int32]string)
16+
sessions map[string]string = make(map[string]string) // Map of net.Addr.String() to challenge string
1717
sessionsMutex *sync.Mutex = &sync.Mutex{}
1818
)
1919

@@ -76,15 +76,15 @@ func handlePacket(data []byte, addr net.Addr) {
7676
switch packetType {
7777
case 0x09: // Generate challenge token
7878
{
79-
if err = writeHandshakePacket(buf, sessionID); err != nil {
79+
if err = writeHandshakePacket(buf, addr, sessionID); err != nil {
8080
return
8181
}
8282

8383
break
8484
}
8585
case 0x00: // Request
8686
{
87-
isFullStat, err := readRequestPacket(r, buf, sessionID)
87+
isFullStat, err := readRequestPacket(r, buf, addr, sessionID)
8888

8989
if err != nil {
9090
return

0 commit comments

Comments
 (0)