diff --git a/modules/logging/netwriter_test.go b/modules/logging/netwriter_test.go index c1f11e4f8..776456722 100644 --- a/modules/logging/netwriter_test.go +++ b/modules/logging/netwriter_test.go @@ -19,9 +19,10 @@ import ( // mockServer represents a simple TCP server for testing type mockServer struct { listener net.Listener + listenerMu sync.RWMutex // Add this line addr string messages []string - mu sync.RWMutex + messagesMu sync.RWMutex wg sync.WaitGroup ctx context.Context cancel context.CancelFunc @@ -58,10 +59,18 @@ func (ms *mockServer) run() { case <-ms.ctx.Done(): return default: - if l, ok := ms.listener.(*net.TCPListener); ok && l != nil { - l.SetDeadline(time.Now().Add(100 * time.Millisecond)) + ms.listenerMu.RLock() + l := ms.listener + if l == nil { + ms.listenerMu.RUnlock() + return } - conn, err := ms.listener.Accept() + if tcpListener, ok := l.(*net.TCPListener); ok && tcpListener != nil { + tcpListener.SetDeadline(time.Now().Add(100 * time.Millisecond)) + } + conn, err := l.Accept() + ms.listenerMu.RUnlock() + if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { continue @@ -100,16 +109,16 @@ func (ms *mockServer) handleConnection(conn net.Conn) { return default: line := scanner.Text() - ms.mu.Lock() + ms.messagesMu.Lock() ms.messages = append(ms.messages, line) - ms.mu.Unlock() + ms.messagesMu.Unlock() } } } func (ms *mockServer) getMessages() []string { - ms.mu.RLock() - defer ms.mu.RUnlock() + ms.messagesMu.RLock() + defer ms.messagesMu.RUnlock() result := make([]string, len(ms.messages)) copy(result, ms.messages) return result @@ -117,7 +126,11 @@ func (ms *mockServer) getMessages() []string { func (ms *mockServer) close() { ms.cancel() - ms.listener.Close() + ms.listenerMu.Lock() + if ms.listener != nil { + ms.listener.Close() + } + ms.listenerMu.Unlock() ms.wg.Wait() } @@ -131,7 +144,12 @@ func (ms *mockServer) stop() { ms.connMu.Unlock() // Then close the listener - ms.listener.Close() + ms.listenerMu.Lock() + if ms.listener != nil { + ms.listener.Close() + ms.listener = nil + } + ms.listenerMu.Unlock() } func (ms *mockServer) restart(t *testing.T) { @@ -139,12 +157,15 @@ func (ms *mockServer) restart(t *testing.T) { if err != nil { t.Fatalf("Failed to restart mock server: %v", err) } - ms.mu.Lock() - ms.listener = listener + ms.listenerMu.Lock() + ms.listener = listener + ms.listenerMu.Unlock() + + ms.messagesMu.Lock() // Clear existing messages to track only new ones ms.messages = nil - ms.mu.Unlock() + ms.messagesMu.Unlock() ms.wg.Add(1) go ms.run() @@ -383,9 +404,9 @@ func TestNetWriter_WALPersistence(t *testing.T) { server.restart(t) // Clear received messages to track only new ones - server.mu.Lock() + server.messagesMu.Lock() server.messages = nil - server.mu.Unlock() + server.messagesMu.Unlock() // Second session: create new NetWriter instance (simulating restart after crash) nw2 := &NetWriter{ @@ -1124,9 +1145,9 @@ func TestNetWriter_WALBufferingDuringOutage(t *testing.T) { } // Store outage messages that might have been received before failure - server.mu.RLock() + server.messagesMu.RLock() preRestartMessages := append([]string(nil), server.messages...) - server.mu.RUnlock() + server.messagesMu.RUnlock() // Restart server server.restart(t)