mirror of
				https://github.com/caddyserver/caddy.git
				synced 2025-10-26 08:12:43 -04:00 
			
		
		
		
	request_id: Allow reusing ID from header (closes #2012)
This commit is contained in:
		
							parent
							
								
									50ab4fe11e
								
							
						
					
					
						commit
						e2997ac974
					
				| @ -16,6 +16,7 @@ package requestid | |||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" | 	"context" | ||||||
|  | 	"log" | ||||||
| 	"net/http" | 	"net/http" | ||||||
| 
 | 
 | ||||||
| 	"github.com/google/uuid" | 	"github.com/google/uuid" | ||||||
| @ -24,12 +25,29 @@ import ( | |||||||
| 
 | 
 | ||||||
| // Handler is a middleware handler | // Handler is a middleware handler | ||||||
| type Handler struct { | type Handler struct { | ||||||
| 	Next httpserver.Handler | 	Next       httpserver.Handler | ||||||
|  | 	HeaderName string // (optional) header from which to read an existing ID | ||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { | func (h Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) (int, error) { | ||||||
| 	reqid := uuid.New().String() | 	var reqid uuid.UUID | ||||||
| 	c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid) | 
 | ||||||
|  | 	uuidFromHeader := r.Header.Get(h.HeaderName) | ||||||
|  | 	if h.HeaderName != "" && uuidFromHeader != "" { | ||||||
|  | 		// use the ID in the header field if it exists | ||||||
|  | 		var err error | ||||||
|  | 		reqid, err = uuid.Parse(uuidFromHeader) | ||||||
|  | 		if err != nil { | ||||||
|  | 			log.Printf("[NOTICE] Parsing request ID from %s header: %v", h.HeaderName, err) | ||||||
|  | 			reqid = uuid.New() | ||||||
|  | 		} | ||||||
|  | 	} else { | ||||||
|  | 		// otherwise, create a new one | ||||||
|  | 		reqid = uuid.New() | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	// set the request ID on the context | ||||||
|  | 	c := context.WithValue(r.Context(), httpserver.RequestIDCtxKey, reqid.String()) | ||||||
| 	r = r.WithContext(c) | 	r = r.WithContext(c) | ||||||
| 
 | 
 | ||||||
| 	return h.Next.ServeHTTP(w, r) | 	return h.Next.ServeHTTP(w, r) | ||||||
|  | |||||||
| @ -15,34 +15,53 @@ | |||||||
| package requestid | package requestid | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"context" |  | ||||||
| 	"net/http" | 	"net/http" | ||||||
|  | 	"net/http/httptest" | ||||||
| 	"testing" | 	"testing" | ||||||
| 
 | 
 | ||||||
| 	"github.com/google/uuid" |  | ||||||
| 	"github.com/mholt/caddy/caddyhttp/httpserver" | 	"github.com/mholt/caddy/caddyhttp/httpserver" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func TestRequestID(t *testing.T) { | func TestRequestIDHandler(t *testing.T) { | ||||||
| 	request, err := http.NewRequest("GET", "http://localhost/", nil) | 	handler := Handler{ | ||||||
|  | 		Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { | ||||||
|  | 			value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string) | ||||||
|  | 			if value == "" { | ||||||
|  | 				t.Error("Request ID should not be empty") | ||||||
|  | 			} | ||||||
|  | 			return 0, nil | ||||||
|  | 		}), | ||||||
|  | 	} | ||||||
|  | 
 | ||||||
|  | 	req, err := http.NewRequest("GET", "http://localhost/", nil) | ||||||
| 	if err != nil { | 	if err != nil { | ||||||
| 		t.Fatal("Could not create HTTP request:", err) | 		t.Fatal("Could not create HTTP request:", err) | ||||||
| 	} | 	} | ||||||
|  | 	rec := httptest.NewRecorder() | ||||||
| 
 | 
 | ||||||
| 	reqid := uuid.New().String() | 	handler.ServeHTTP(rec, req) | ||||||
| 
 | } | ||||||
| 	c := context.WithValue(request.Context(), httpserver.RequestIDCtxKey, reqid) | 
 | ||||||
| 
 | func TestRequestIDFromHeader(t *testing.T) { | ||||||
| 	request = request.WithContext(c) | 	headerName := "X-Request-ID" | ||||||
| 
 | 	headerValue := "71a75329-d9f9-4d25-957e-e689a7b68d78" | ||||||
| 	// See caddyhttp/replacer.go | 	handler := Handler{ | ||||||
| 	value, _ := request.Context().Value(httpserver.RequestIDCtxKey).(string) | 		Next: httpserver.HandlerFunc(func(w http.ResponseWriter, r *http.Request) (int, error) { | ||||||
| 
 | 			value, _ := r.Context().Value(httpserver.RequestIDCtxKey).(string) | ||||||
| 	if value == "" { | 			if value != headerValue { | ||||||
| 		t.Fatal("Request ID should not be empty") | 				t.Errorf("Request ID should be '%s' but got '%s'", headerValue, value) | ||||||
| 	} | 			} | ||||||
| 
 | 			return 0, nil | ||||||
| 	if value != reqid { | 		}), | ||||||
| 		t.Fatal("Request ID does not match") | 		HeaderName: headerName, | ||||||
| 	} | 	} | ||||||
|  | 
 | ||||||
|  | 	req, err := http.NewRequest("GET", "http://localhost/", nil) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Fatal("Could not create HTTP request:", err) | ||||||
|  | 	} | ||||||
|  | 	req.Header.Set(headerName, headerValue) | ||||||
|  | 	rec := httptest.NewRecorder() | ||||||
|  | 
 | ||||||
|  | 	handler.ServeHTTP(rec, req) | ||||||
| } | } | ||||||
|  | |||||||
| @ -27,14 +27,19 @@ func init() { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func setup(c *caddy.Controller) error { | func setup(c *caddy.Controller) error { | ||||||
|  | 	var headerName string | ||||||
|  | 
 | ||||||
| 	for c.Next() { | 	for c.Next() { | ||||||
| 		if c.NextArg() { | 		if c.NextArg() { | ||||||
| 			return c.ArgErr() //no arg expected. | 			headerName = c.Val() | ||||||
|  | 		} | ||||||
|  | 		if c.NextArg() { | ||||||
|  | 			return c.ArgErr() | ||||||
| 		} | 		} | ||||||
| 	} | 	} | ||||||
| 
 | 
 | ||||||
| 	httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { | 	httpserver.GetConfig(c).AddMiddleware(func(next httpserver.Handler) httpserver.Handler { | ||||||
| 		return Handler{Next: next} | 		return Handler{Next: next, HeaderName: headerName} | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	return nil | 	return nil | ||||||
|  | |||||||
| @ -45,7 +45,15 @@ func TestSetup(t *testing.T) { | |||||||
| } | } | ||||||
| 
 | 
 | ||||||
| func TestSetupWithArg(t *testing.T) { | func TestSetupWithArg(t *testing.T) { | ||||||
| 	c := caddy.NewTestController("http", `requestid abc`) | 	c := caddy.NewTestController("http", `requestid X-Request-ID`) | ||||||
|  | 	err := setup(c) | ||||||
|  | 	if err != nil { | ||||||
|  | 		t.Errorf("Expected no error, got: %v", err) | ||||||
|  | 	} | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func TestSetupWithTooManyArgs(t *testing.T) { | ||||||
|  | 	c := caddy.NewTestController("http", `requestid foo bar`) | ||||||
| 	err := setup(c) | 	err := setup(c) | ||||||
| 	if err == nil { | 	if err == nil { | ||||||
| 		t.Errorf("Expected an error, got: %v", err) | 		t.Errorf("Expected an error, got: %v", err) | ||||||
|  | |||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user