diff --git a/http/middleware/README.md b/http/middleware/README.md index 9bdf1fc..caa3b35 100644 --- a/http/middleware/README.md +++ b/http/middleware/README.md @@ -15,11 +15,11 @@ mux.HandleFunc("GET /api/data", logRequests(timeRequests(requireAuth(requireAdmi Into organizable this: ```go -mw := middleware.New(logRequests, timeRequests) -mux.HandleFunc("GET /api/version", mw.Handle(getVersion)) +mw := middleware.WithMux(mux, logRequests, timeRequests) +mw.HandleFunc("GET /api/version", getVersion) -authMW := m.Use(requireAuth, requireAdmin) -mux.HandleFunc("GET /api/data", authMW.Handle(getData)) +authMW := mw.With(requireAuth, requireAdmin) +authMW.HandleFunc("GET /api/data", getData) ``` Using stdlib this: diff --git a/http/middleware/middleware.go b/http/middleware/middleware.go index c472028..49fd5f9 100644 --- a/http/middleware/middleware.go +++ b/http/middleware/middleware.go @@ -29,15 +29,6 @@ func New(middlewares ...Middleware) MiddlewareChain { return MiddlewareChain{middlewares: middlewares} } -// Use appends additional middleware to the chain -func (c MiddlewareChain) Use(middlewares ...Middleware) MiddlewareChain { - newMiddlewares := make([]Middleware, len(c.middlewares), len(c.middlewares)+len(middlewares)) - copy(newMiddlewares, c.middlewares) - newMiddlewares = append(newMiddlewares, middlewares...) - - return MiddlewareChain{middlewares: newMiddlewares} -} - // Handle composes middleware with the final handler func (c MiddlewareChain) Handle(handler http.HandlerFunc) http.HandlerFunc { if handler == nil { @@ -56,3 +47,70 @@ func (c MiddlewareChain) Handle(handler http.HandlerFunc) http.HandlerFunc { return result } + +// Use appends additional middleware to the chain +func (c MiddlewareChain) Use(middlewares ...Middleware) MiddlewareChain { + newMiddlewares := make([]Middleware, len(c.middlewares), len(c.middlewares)+len(middlewares)) + copy(newMiddlewares, c.middlewares) + newMiddlewares = append(newMiddlewares, middlewares...) + + return MiddlewareChain{middlewares: newMiddlewares} +} + +type Muxer interface { + Handle(path string, handler http.Handler) + HandleFunc(path string, handle func(w http.ResponseWriter, r *http.Request)) +} + +// MiddlewareMux enables inline chaining +type MiddlewareMux struct { + middlewares []Middleware + mux Muxer +} + +// WithMux wraps a mux such so that Handle and HandleFunc apply the middleware chain +func WithMux(mux Muxer, middlewares ...Middleware) MiddlewareMux { + return MiddlewareMux{ + middlewares: middlewares, + mux: mux, + } +} + +// With creates a new copy of the chain with the specified middleware appended +func (c MiddlewareMux) With(middlewares ...Middleware) MiddlewareMux { + newMiddlewares := make([]Middleware, len(c.middlewares), len(c.middlewares)+len(middlewares)) + copy(newMiddlewares, c.middlewares) + newMiddlewares = append(newMiddlewares, middlewares...) + + return MiddlewareMux{ + mux: c.mux, + middlewares: newMiddlewares, + } +} + +func (c MiddlewareMux) Handle(path string, handler http.Handler) { + c.mux.Handle(path, c.handle(handler.ServeHTTP)) +} + +func (c MiddlewareMux) HandleFunc(path string, handler http.HandlerFunc) { + c.mux.HandleFunc(path, c.handle(handler)) +} + +// Handle composes middleware with the final handler +func (c MiddlewareMux) handle(handler http.HandlerFunc) http.HandlerFunc { + if handler == nil { + panic("mw.New(...).Use(...).Handle(-->this<--) requires a handler") + } + + middlewares := make([]Middleware, len(c.middlewares)) + copy(middlewares, c.middlewares) + slices.Reverse(middlewares) + + // Apply middleware in forward order + result := handler + for _, m := range middlewares { + result = m(result) + } + + return result +}