Fix race condition between WAL readers and writers in NetWriter

Signed-off-by: Mohammed Al Sahaf <msaa1990@gmail.com>
This commit is contained in:
Mohammed Al Sahaf 2025-08-03 04:21:22 +03:00
parent 9f586657e8
commit 7ac7ca3ff4
No known key found for this signature in database
2 changed files with 25 additions and 5 deletions

View File

@ -67,6 +67,7 @@ type NetWriter struct {
flushWg sync.WaitGroup flushWg sync.WaitGroup
lastProcessedOffset int64 lastProcessedOffset int64
mu sync.RWMutex mu sync.RWMutex
walMu sync.Mutex // synchronizes WAL read/write operations
} }
// CaddyModule returns the Caddy module information. // CaddyModule returns the Caddy module information.
@ -193,7 +194,7 @@ func (nw *NetWriter) loadLastProcessedOffset() {
// saveLastProcessedOffset saves the last processed offset to a metadata file // saveLastProcessedOffset saves the last processed offset to a metadata file
func (nw *NetWriter) saveLastProcessedOffset(cp *wal.ChunkPosition) { func (nw *NetWriter) saveLastProcessedOffset(cp *wal.ChunkPosition) {
// Create a unique offset by combining segment, block, and chunk offset // Create a unique offset by combining segment, block, and chunk offset
offset := (int64(cp.SegmentId) << 32) | (int64(cp.BlockNumber) << 16) | int64(cp.ChunkOffset) offset := (int64(cp.SegmentId) << 32) | (int64(cp.BlockNumber) << 16) | (cp.ChunkOffset)
nw.mu.Lock() nw.mu.Lock()
nw.lastProcessedOffset = offset nw.lastProcessedOffset = offset
@ -314,14 +315,20 @@ func (nw *NetWriter) backgroundFlusher() {
// processWALEntries processes all available entries in the WAL using a fresh reader // processWALEntries processes all available entries in the WAL using a fresh reader
func (nw *NetWriter) processWALEntries(writeToConn func([]byte) error) { func (nw *NetWriter) processWALEntries(writeToConn func([]byte) error) {
// Synchronize WAL access to prevent race conditions with writers
nw.walMu.Lock()
// Create a fresh reader to see all current entries // Create a fresh reader to see all current entries
reader := nw.wal.NewReader() reader := nw.wal.NewReader()
nw.walMu.Unlock()
processed := 0 processed := 0
skipped := 0 skipped := 0
nw.logger.Debug("processing available WAL entries") nw.logger.Debug("processing available WAL entries")
for { for {
nw.walMu.Lock()
data, cp, err := reader.Next() data, cp, err := reader.Next()
nw.walMu.Unlock()
if err == io.EOF { if err == io.EOF {
if processed > 0 { if processed > 0 {
nw.logger.Debug("processed WAL entries", "processed", processed, "skipped", skipped) nw.logger.Debug("processed WAL entries", "processed", processed, "skipped", skipped)
@ -343,7 +350,7 @@ func (nw *NetWriter) processWALEntries(writeToConn func([]byte) error) {
nw.mu.RUnlock() nw.mu.RUnlock()
// Create current entry offset for comparison // Create current entry offset for comparison
currentOffset := (int64(cp.SegmentId) << 32) | (int64(cp.BlockNumber) << 16) | int64(cp.ChunkOffset) currentOffset := (int64(cp.SegmentId) << 32) | (int64(cp.BlockNumber) << 16) | (cp.ChunkOffset)
nw.logger.Debug("found WAL entry", "segmentId", cp.SegmentId, "blockNumber", cp.BlockNumber, "chunkOffset", cp.ChunkOffset, "currentOffset", currentOffset, "lastProcessedOffset", lastProcessedOffset, "size", len(data)) nw.logger.Debug("found WAL entry", "segmentId", cp.SegmentId, "blockNumber", cp.BlockNumber, "chunkOffset", cp.ChunkOffset, "currentOffset", currentOffset, "lastProcessedOffset", lastProcessedOffset, "size", len(data))
if currentOffset <= lastProcessedOffset { if currentOffset <= lastProcessedOffset {
@ -380,12 +387,18 @@ func (nw *NetWriter) processWALEntry(data []byte, cp *wal.ChunkPosition, writeTo
func (nw *NetWriter) flushRemainingWALEntries(writeToConn func([]byte) error) { func (nw *NetWriter) flushRemainingWALEntries(writeToConn func([]byte) error) {
nw.logger.Info("flushing remaining WAL entries during shutdown") nw.logger.Info("flushing remaining WAL entries during shutdown")
// Synchronize WAL access to prevent race conditions with writers
nw.walMu.Lock()
// Create a fresh reader for shutdown processing // Create a fresh reader for shutdown processing
reader := nw.wal.NewReader() reader := nw.wal.NewReader()
nw.walMu.Unlock()
count := 0 count := 0
for { for {
nw.walMu.Lock()
data, cp, err := reader.Next() data, cp, err := reader.Next()
nw.walMu.Unlock()
if err == io.EOF { if err == io.EOF {
break break
} }
@ -400,7 +413,7 @@ func (nw *NetWriter) flushRemainingWALEntries(writeToConn func([]byte) error) {
nw.mu.RUnlock() nw.mu.RUnlock()
// Create current entry offset for comparison // Create current entry offset for comparison
currentOffset := (int64(cp.SegmentId) << 32) | (int64(cp.BlockNumber) << 16) | int64(cp.ChunkOffset) currentOffset := (int64(cp.SegmentId) << 32) | (int64(cp.BlockNumber) << 16) | (cp.ChunkOffset)
if currentOffset <= lastProcessedOffset { if currentOffset <= lastProcessedOffset {
// Already processed, skip // Already processed, skip
@ -446,6 +459,10 @@ func (w *netWriterConn) Write(p []byte) (n int, err error) {
w.nw.logger.Debug("writing to WAL", "size", len(p)) w.nw.logger.Debug("writing to WAL", "size", len(p))
// Synchronize WAL access to prevent race conditions
w.nw.walMu.Lock()
defer w.nw.walMu.Unlock()
// Write to WAL - this should be fast and non-blocking // Write to WAL - this should be fast and non-blocking
_, err = w.nw.wal.Write(p) _, err = w.nw.wal.Write(p)
if err != nil { if err != nil {
@ -473,14 +490,16 @@ func (w *netWriterConn) Close() error {
var errs []error var errs []error
// Sync and close WAL // Sync and close WAL with synchronization
if w.nw.wal != nil { if w.nw.wal != nil {
w.nw.walMu.Lock()
if err := w.nw.wal.Sync(); err != nil { if err := w.nw.wal.Sync(); err != nil {
errs = append(errs, fmt.Errorf("WAL sync error: %v", err)) errs = append(errs, fmt.Errorf("WAL sync error: %v", err))
} }
if err := w.nw.wal.Close(); err != nil { if err := w.nw.wal.Close(); err != nil {
errs = append(errs, fmt.Errorf("WAL close error: %v", err)) errs = append(errs, fmt.Errorf("WAL close error: %v", err))
} }
w.nw.walMu.Unlock()
} }
if len(errs) > 0 { if len(errs) > 0 {

View File

@ -885,7 +885,8 @@ func TestNetWriter_UnmarshalCaddyfile(t *testing.T) {
}, },
} }
for _, tt := range tests { for i := range tests {
tt := tests[i] //nolint:copylocks
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
d := caddyfile.NewTestDispenser(tt.input) d := caddyfile.NewTestDispenser(tt.input)
nw := &NetWriter{} nw := &NetWriter{}