add mutex for the listener to resolve data race

Signed-off-by: Mohammed Al Sahaf <msaa1990@gmail.com>
This commit is contained in:
Mohammed Al Sahaf 2025-08-04 00:40:52 +03:00
parent e6d44851b1
commit 7668108b5d
No known key found for this signature in database

View File

@ -19,9 +19,10 @@ import (
// mockServer represents a simple TCP server for testing // mockServer represents a simple TCP server for testing
type mockServer struct { type mockServer struct {
listener net.Listener listener net.Listener
listenerMu sync.RWMutex // Add this line
addr string addr string
messages []string messages []string
mu sync.RWMutex messagesMu sync.RWMutex
wg sync.WaitGroup wg sync.WaitGroup
ctx context.Context ctx context.Context
cancel context.CancelFunc cancel context.CancelFunc
@ -58,10 +59,18 @@ func (ms *mockServer) run() {
case <-ms.ctx.Done(): case <-ms.ctx.Done():
return return
default: default:
if l, ok := ms.listener.(*net.TCPListener); ok && l != nil { ms.listenerMu.RLock()
l.SetDeadline(time.Now().Add(100 * time.Millisecond)) l := ms.listener
if l == nil {
ms.listenerMu.RUnlock()
return
} }
conn, err := ms.listener.Accept() if tcpListener, ok := l.(*net.TCPListener); ok && tcpListener != nil {
tcpListener.SetDeadline(time.Now().Add(100 * time.Millisecond))
}
conn, err := l.Accept()
ms.listenerMu.RUnlock()
if err != nil { if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() { if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue continue
@ -100,16 +109,16 @@ func (ms *mockServer) handleConnection(conn net.Conn) {
return return
default: default:
line := scanner.Text() line := scanner.Text()
ms.mu.Lock() ms.messagesMu.Lock()
ms.messages = append(ms.messages, line) ms.messages = append(ms.messages, line)
ms.mu.Unlock() ms.messagesMu.Unlock()
} }
} }
} }
func (ms *mockServer) getMessages() []string { func (ms *mockServer) getMessages() []string {
ms.mu.RLock() ms.messagesMu.RLock()
defer ms.mu.RUnlock() defer ms.messagesMu.RUnlock()
result := make([]string, len(ms.messages)) result := make([]string, len(ms.messages))
copy(result, ms.messages) copy(result, ms.messages)
return result return result
@ -117,7 +126,11 @@ func (ms *mockServer) getMessages() []string {
func (ms *mockServer) close() { func (ms *mockServer) close() {
ms.cancel() ms.cancel()
ms.listener.Close() ms.listenerMu.Lock()
if ms.listener != nil {
ms.listener.Close()
}
ms.listenerMu.Unlock()
ms.wg.Wait() ms.wg.Wait()
} }
@ -131,7 +144,12 @@ func (ms *mockServer) stop() {
ms.connMu.Unlock() ms.connMu.Unlock()
// Then close the listener // Then close the listener
ms.listener.Close() ms.listenerMu.Lock()
if ms.listener != nil {
ms.listener.Close()
ms.listener = nil
}
ms.listenerMu.Unlock()
} }
func (ms *mockServer) restart(t *testing.T) { func (ms *mockServer) restart(t *testing.T) {
@ -139,12 +157,15 @@ func (ms *mockServer) restart(t *testing.T) {
if err != nil { if err != nil {
t.Fatalf("Failed to restart mock server: %v", err) t.Fatalf("Failed to restart mock server: %v", err)
} }
ms.mu.Lock()
ms.listener = listener
ms.listenerMu.Lock()
ms.listener = listener
ms.listenerMu.Unlock()
ms.messagesMu.Lock()
// Clear existing messages to track only new ones // Clear existing messages to track only new ones
ms.messages = nil ms.messages = nil
ms.mu.Unlock() ms.messagesMu.Unlock()
ms.wg.Add(1) ms.wg.Add(1)
go ms.run() go ms.run()
@ -383,9 +404,9 @@ func TestNetWriter_WALPersistence(t *testing.T) {
server.restart(t) server.restart(t)
// Clear received messages to track only new ones // Clear received messages to track only new ones
server.mu.Lock() server.messagesMu.Lock()
server.messages = nil server.messages = nil
server.mu.Unlock() server.messagesMu.Unlock()
// Second session: create new NetWriter instance (simulating restart after crash) // Second session: create new NetWriter instance (simulating restart after crash)
nw2 := &NetWriter{ nw2 := &NetWriter{
@ -1124,9 +1145,9 @@ func TestNetWriter_WALBufferingDuringOutage(t *testing.T) {
} }
// Store outage messages that might have been received before failure // Store outage messages that might have been received before failure
server.mu.RLock() server.messagesMu.RLock()
preRestartMessages := append([]string(nil), server.messages...) preRestartMessages := append([]string(nil), server.messages...)
server.mu.RUnlock() server.messagesMu.RUnlock()
// Restart server // Restart server
server.restart(t) server.restart(t)