@@ -5,17 +5,18 @@ import (
55 "errors"
66 "fmt"
77 "io"
8+ "net"
89 "main/src/util"
910 "math/rand"
1011 "strconv"
1112 "strings"
1213)
1314
14- func writeHandshakePacket (w io.Writer , sessionID int32 ) error {
15+ func writeHandshakePacket (w io.Writer , addr net. Addr , sessionID int32 ) error {
1516 challengeToken := strconv .FormatInt (int64 (rand .Int31 ()), 10 )
1617
1718 sessionsMutex .Lock ()
18- sessions [sessionID ] = challengeToken
19+ sessions [addr . String () ] = challengeToken
1920 sessionsMutex .Unlock ()
2021
2122 // Type - byte
@@ -36,13 +37,13 @@ func writeHandshakePacket(w io.Writer, sessionID int32) error {
3637 return nil
3738}
3839
39- func readRequestPacket (r io.Reader , w io.Writer , sessionID int32 ) (bool , error ) {
40+ func readRequestPacket (r io.Reader , w io.Writer , addr net. Addr , sessionID int32 ) (bool , error ) {
4041 sessionsMutex .Lock ()
4142
4243 defer sessionsMutex .Unlock ()
4344
44- if _ , ok := sessions [sessionID ]; ! ok {
45- return false , fmt .Errorf ("query: invalid or expired session ID: %X " , sessionID )
45+ if _ , ok := sessions [addr . String () ]; ! ok {
46+ return false , fmt .Errorf ("query: no currently active challenges for %s " , addr . String () )
4647 }
4748
4849 // Challenge Token - int32
@@ -53,7 +54,7 @@ func readRequestPacket(r io.Reader, w io.Writer, sessionID int32) (bool, error)
5354 return false , err
5455 }
5556
56- if sessions [sessionID ] != strconv .FormatInt (int64 (challengeToken ), 10 ) {
57+ if sessions [addr . String () ] != strconv .FormatInt (int64 (challengeToken ), 10 ) {
5758 return false , fmt .Errorf ("query: received challenge token did not match stored" )
5859 }
5960 }
0 commit comments