add mutex for the listener to resolve data race

Signed-off-by: Mohammed Al Sahaf <msaa1990@gmail.com>
This commit is contained in:
Mohammed Al Sahaf 2025-08-04 00:40:52 +03:00
parent e6d44851b1
commit 7668108b5d
No known key found for this signature in database

View File

@ -19,9 +19,10 @@ import (
// mockServer represents a simple TCP server for testing
type mockServer struct {
listener net.Listener
listenerMu sync.RWMutex // Add this line
addr string
messages []string
mu sync.RWMutex
messagesMu sync.RWMutex
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
@ -58,10 +59,18 @@ func (ms *mockServer) run() {
case <-ms.ctx.Done():
return
default:
if l, ok := ms.listener.(*net.TCPListener); ok && l != nil {
l.SetDeadline(time.Now().Add(100 * time.Millisecond))
ms.listenerMu.RLock()
l := ms.listener
if l == nil {
ms.listenerMu.RUnlock()
return
}
conn, err := ms.listener.Accept()
if tcpListener, ok := l.(*net.TCPListener); ok && tcpListener != nil {
tcpListener.SetDeadline(time.Now().Add(100 * time.Millisecond))
}
conn, err := l.Accept()
ms.listenerMu.RUnlock()
if err != nil {
if netErr, ok := err.(net.Error); ok && netErr.Timeout() {
continue
@ -100,16 +109,16 @@ func (ms *mockServer) handleConnection(conn net.Conn) {
return
default:
line := scanner.Text()
ms.mu.Lock()
ms.messagesMu.Lock()
ms.messages = append(ms.messages, line)
ms.mu.Unlock()
ms.messagesMu.Unlock()
}
}
}
func (ms *mockServer) getMessages() []string {
ms.mu.RLock()
defer ms.mu.RUnlock()
ms.messagesMu.RLock()
defer ms.messagesMu.RUnlock()
result := make([]string, len(ms.messages))
copy(result, ms.messages)
return result
@ -117,7 +126,11 @@ func (ms *mockServer) getMessages() []string {
func (ms *mockServer) close() {
ms.cancel()
ms.listenerMu.Lock()
if ms.listener != nil {
ms.listener.Close()
}
ms.listenerMu.Unlock()
ms.wg.Wait()
}
@ -131,7 +144,12 @@ func (ms *mockServer) stop() {
ms.connMu.Unlock()
// Then close the listener
ms.listenerMu.Lock()
if ms.listener != nil {
ms.listener.Close()
ms.listener = nil
}
ms.listenerMu.Unlock()
}
func (ms *mockServer) restart(t *testing.T) {
@ -139,12 +157,15 @@ func (ms *mockServer) restart(t *testing.T) {
if err != nil {
t.Fatalf("Failed to restart mock server: %v", err)
}
ms.mu.Lock()
ms.listener = listener
ms.listenerMu.Lock()
ms.listener = listener
ms.listenerMu.Unlock()
ms.messagesMu.Lock()
// Clear existing messages to track only new ones
ms.messages = nil
ms.mu.Unlock()
ms.messagesMu.Unlock()
ms.wg.Add(1)
go ms.run()
@ -383,9 +404,9 @@ func TestNetWriter_WALPersistence(t *testing.T) {
server.restart(t)
// Clear received messages to track only new ones
server.mu.Lock()
server.messagesMu.Lock()
server.messages = nil
server.mu.Unlock()
server.messagesMu.Unlock()
// Second session: create new NetWriter instance (simulating restart after crash)
nw2 := &NetWriter{
@ -1124,9 +1145,9 @@ func TestNetWriter_WALBufferingDuringOutage(t *testing.T) {
}
// Store outage messages that might have been received before failure
server.mu.RLock()
server.messagesMu.RLock()
preRestartMessages := append([]string(nil), server.messages...)
server.mu.RUnlock()
server.messagesMu.RUnlock()
// Restart server
server.restart(t)