diff --git a/modules/web/handler.go b/modules/web/handler.go index 843b17e8d1..d113bbba45 100644 --- a/modules/web/handler.go +++ b/modules/web/handler.go @@ -70,7 +70,8 @@ func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) { func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value { defer func() { - if err := recover(); err != nil { + if recovered := recover(); recovered != nil { + err := fmt.Errorf("%v\n%s", recovered, log.Stack(2)) log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err) panic(err) } @@ -117,7 +118,17 @@ func hasResponseBeenWritten(argsIn []reflect.Value) bool { return false } -func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) func(next http.Handler) http.Handler { +type middlewareProvider = func(next http.Handler) http.Handler + +func executeMiddlewaresHandler(w http.ResponseWriter, r *http.Request, middlewares []middlewareProvider, endpoint http.HandlerFunc) { + handler := endpoint + for i := len(middlewares) - 1; i >= 0; i-- { + handler = middlewares[i](handler).ServeHTTP + } + handler(w, r) +} + +func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) middlewareProvider { return func(next http.Handler) http.Handler { h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { @@ -129,14 +140,14 @@ func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo // toHandlerProvider converts a handler to a handler provider // A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware -func toHandlerProvider(handler any) func(next http.Handler) http.Handler { +func toHandlerProvider(handler any) middlewareProvider { funcInfo := routing.GetFuncInfo(handler) fn := reflect.ValueOf(handler) if fn.Type().Kind() != reflect.Func { panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type())) } - if hp, ok := handler.(func(next http.Handler) http.Handler); ok { + if hp, ok := handler.(middlewareProvider); ok { return wrapHandlerProvider(hp, funcInfo) } else if hp, ok := handler.(func(http.Handler) http.HandlerFunc); ok { return wrapHandlerProvider(hp, funcInfo) diff --git a/modules/web/router.go b/modules/web/router.go index 5374f82a23..5ef18e9679 100644 --- a/modules/web/router.go +++ b/modules/web/router.go @@ -18,6 +18,13 @@ import ( "github.com/go-chi/chi/v5" ) +// PreMiddlewareProvider is a special middleware provider which will be executed +// before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting). +// A route can do something (e.g.: set middleware options) at the place where it is declared, +// and the code will be executed before other middlewares which are added before the declaration. +// Use cases: mark a route with some meta info, set some options for middlewares, etc. +type PreMiddlewareProvider func(next http.Handler) http.Handler + // Bind binding an obj to a handler's context data func Bind[T any](_ T) http.HandlerFunc { return func(resp http.ResponseWriter, req *http.Request) { @@ -41,7 +48,10 @@ func GetForm(dataStore reqctx.RequestDataStore) any { // Router defines a route based on chi's router type Router struct { - chiRouter *chi.Mux + chiRouter *chi.Mux + + afterRouting []any + curGroupPrefix string curMiddlewares []any } @@ -52,8 +62,9 @@ func NewRouter() *Router { return &Router{chiRouter: r} } -// Use supports two middlewares -func (r *Router) Use(middlewares ...any) { +// BeforeRouting adds middlewares which will be executed before the request path gets routed +// It should only be used for framework-level global middlewares when it needs to change request method & path. +func (r *Router) BeforeRouting(middlewares ...any) { for _, m := range middlewares { if !isNilOrFuncNil(m) { r.chiRouter.Use(toHandlerProvider(m)) @@ -61,7 +72,13 @@ func (r *Router) Use(middlewares ...any) { } } -// Group mounts a sub-Router along a `pattern` string. +// AfterRouting adds middlewares which will be executed after the request path gets routed +// It can see the routed path and resolved path parameters +func (r *Router) AfterRouting(middlewares ...any) { + r.afterRouting = append(r.afterRouting, middlewares...) +} + +// Group mounts a sub-router along a "pattern" string. func (r *Router) Group(pattern string, fn func(), middlewares ...any) { previousGroupPrefix := r.curGroupPrefix previousMiddlewares := r.curMiddlewares @@ -93,36 +110,54 @@ func isNilOrFuncNil(v any) bool { return r.Kind() == reflect.Func && r.IsNil() } -func wrapMiddlewareAndHandler(curMiddlewares, h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) { - handlerProviders := make([]func(http.Handler) http.Handler, 0, len(curMiddlewares)+len(h)+1) - for _, m := range curMiddlewares { - if !isNilOrFuncNil(m) { - handlerProviders = append(handlerProviders, toHandlerProvider(m)) +func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []middlewareProvider { + for _, m := range middlewares { + if h, ok := m.(PreMiddlewareProvider); ok && h != nil { + all = append(all, toHandlerProvider(middlewareProvider(h))) } } + return all +} + +func wrapMiddlewareAppendNormal(all []middlewareProvider, middlewares []any) []middlewareProvider { + for _, m := range middlewares { + if _, ok := m.(PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) { + all = append(all, toHandlerProvider(m)) + } + } + return all +} + +func wrapMiddlewareAndHandler(useMiddlewares, curMiddlewares, h []any) (_ []middlewareProvider, _ http.HandlerFunc, hasPreMiddlewares bool) { if len(h) == 0 { panic("no endpoint handler provided") } - for i, m := range h { - if !isNilOrFuncNil(m) { - handlerProviders = append(handlerProviders, toHandlerProvider(m)) - } else if i == len(h)-1 { - panic("endpoint handler can't be nil") - } + if isNilOrFuncNil(h[len(h)-1]) { + panic("endpoint handler can't be nil") } + + handlerProviders := make([]middlewareProvider, 0, len(useMiddlewares)+len(curMiddlewares)+len(h)+1) + handlerProviders = wrapMiddlewareAppendPre(handlerProviders, useMiddlewares) + handlerProviders = wrapMiddlewareAppendPre(handlerProviders, curMiddlewares) + handlerProviders = wrapMiddlewareAppendPre(handlerProviders, h) + hasPreMiddlewares = len(handlerProviders) > 0 + handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, useMiddlewares) + handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, curMiddlewares) + handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, h) + middlewares := handlerProviders[:len(handlerProviders)-1] handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP mockPoint := RouterMockPoint(MockAfterMiddlewares) if mockPoint != nil { middlewares = append(middlewares, mockPoint) } - return middlewares, handlerFunc + return middlewares, handlerFunc, hasPreMiddlewares } // Methods adds the same handlers for multiple http "methods" (separated by ","). // If any method is invalid, the lower level router will panic. func (r *Router) Methods(methods, pattern string, h ...any) { - middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h) + middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h) fullPattern := r.getPattern(pattern) if strings.Contains(methods, ",") { methods := strings.SplitSeq(methods, ",") @@ -134,15 +169,19 @@ func (r *Router) Methods(methods, pattern string, h ...any) { } } -// Mount attaches another Router along ./pattern/* +// Mount attaches another Router along "/pattern/*" func (r *Router) Mount(pattern string, subRouter *Router) { - subRouter.Use(r.curMiddlewares...) - r.chiRouter.Mount(r.getPattern(pattern), subRouter.chiRouter) + handlerProviders := make([]middlewareProvider, 0, len(r.afterRouting)+len(r.curMiddlewares)) + handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.afterRouting) + handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.curMiddlewares) + handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.afterRouting) + handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.curMiddlewares) + r.chiRouter.With(handlerProviders...).Mount(r.getPattern(pattern), subRouter.chiRouter) } // Any delegate requests for all methods func (r *Router) Any(pattern string, h ...any) { - middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h) + middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h) r.chiRouter.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc) } @@ -178,12 +217,16 @@ func (r *Router) Patch(pattern string, h ...any) { // ServeHTTP implements http.Handler func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) { + // TODO: need to move it to the top-level common middleware, otherwise each "Mount" will cause it to be executed multiple times, which is inefficient. r.normalizeRequestPath(w, req, r.chiRouter) } // NotFound defines a handler to respond whenever a route could not be found. func (r *Router) NotFound(h http.HandlerFunc) { - r.chiRouter.NotFound(h) + middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, []any{h}) + r.chiRouter.NotFound(func(w http.ResponseWriter, r *http.Request) { + executeMiddlewaresHandler(w, r, middlewares, handlerFunc) + }) } func (r *Router) normalizeRequestPath(resp http.ResponseWriter, req *http.Request, next http.Handler) { diff --git a/modules/web/router_path.go b/modules/web/router_path.go index 64154c34a5..9e531346d1 100644 --- a/modules/web/router_path.go +++ b/modules/web/router_path.go @@ -27,11 +27,7 @@ func (g *RouterPathGroup) ServeHTTP(resp http.ResponseWriter, req *http.Request) for _, m := range g.matchers { if m.matchPath(chiCtx, path) { chiCtx.RoutePatterns = append(chiCtx.RoutePatterns, m.pattern) - handler := m.handlerFunc - for i := len(m.middlewares) - 1; i >= 0; i-- { - handler = m.middlewares[i](handler).ServeHTTP - } - handler(resp, req) + executeMiddlewaresHandler(resp, req, m.middlewares, m.handlerFunc) return } } @@ -67,7 +63,7 @@ type routerPathMatcher struct { pattern string re *regexp.Regexp params []routerPathParam - middlewares []func(http.Handler) http.Handler + middlewares []middlewareProvider handlerFunc http.HandlerFunc } @@ -111,7 +107,10 @@ func isValidMethod(name string) bool { } func newRouterPathMatcher(methods string, patternRegexp *RouterPathGroupPattern, h ...any) *routerPathMatcher { - middlewares, handlerFunc := wrapMiddlewareAndHandler(patternRegexp.middlewares, h) + middlewares, handlerFunc, hasPreMiddlewares := wrapMiddlewareAndHandler(nil, patternRegexp.middlewares, h) + if hasPreMiddlewares { + panic("pre-middlewares are not supported in router path matcher") + } p := &routerPathMatcher{methods: make(container.Set[string]), middlewares: middlewares, handlerFunc: handlerFunc} for method := range strings.SplitSeq(methods, ",") { method = strings.TrimSpace(method) diff --git a/modules/web/router_test.go b/modules/web/router_test.go index ab5fbb502c..645e70b869 100644 --- a/modules/web/router_test.go +++ b/modules/web/router_test.go @@ -30,6 +30,71 @@ func chiURLParamsToMap(chiCtx *chi.Context) map[string]string { return util.Iif(len(m) == 0, nil, m) } +type testResult struct { + method string + pathParams map[string]string + handlerMarks []string + chiRoutePattern *string +} + +type testRecorder struct { + res testResult +} + +func (r *testRecorder) reset() { + r.res = testResult{} +} + +func (r *testRecorder) handle(optMark ...string) func(resp http.ResponseWriter, req *http.Request) { + mark := util.OptionalArg(optMark, "") + return func(resp http.ResponseWriter, req *http.Request) { + chiCtx := chi.RouteContext(req.Context()) + r.res.method = req.Method + r.res.pathParams = chiURLParamsToMap(chiCtx) + r.res.chiRoutePattern = new(chiCtx.RoutePattern()) + if mark != "" { + r.res.handlerMarks = append(r.res.handlerMarks, mark) + } + } +} + +func (r *testRecorder) provider(optMark ...string) func(next http.Handler) http.Handler { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) { + r.handle(optMark...)(resp, req) + next.ServeHTTP(resp, req) + }) + } +} + +func (r *testRecorder) stop(optMark ...string) func(resp http.ResponseWriter, req *http.Request) { + mark := util.OptionalArg(optMark, "") + return func(resp http.ResponseWriter, req *http.Request) { + if stop := req.FormValue("stop"); stop != "" && (mark == "" || mark == stop) { + r.handle(stop)(resp, req) + resp.WriteHeader(http.StatusOK) + } else if mark != "" { + r.res.handlerMarks = append(r.res.handlerMarks, mark) + } + } +} + +func (r *testRecorder) test(t *testing.T, rt *Router, methodPath string, expected testResult) { + r.reset() + methodPathFields := strings.Fields(methodPath) + req, err := http.NewRequest(methodPathFields[0], methodPathFields[1], nil) + assert.NoError(t, err) + + buff := &bytes.Buffer{} + httpRecorder := httptest.NewRecorder() + httpRecorder.Body = buff + rt.ServeHTTP(httpRecorder, req) + if expected.chiRoutePattern == nil { + r.res.chiRoutePattern = nil + } + assert.Equal(t, expected, r.res) +} + func TestPathProcessor(t *testing.T) { testProcess := func(pattern, uri string, expectedPathParams map[string]string) { chiCtx := chi.NewRouteContext() @@ -51,42 +116,10 @@ func TestPathProcessor(t *testing.T) { } func TestRouter(t *testing.T) { - buff := &bytes.Buffer{} - recorder := httptest.NewRecorder() - recorder.Body = buff - - type resultStruct struct { - method string - pathParams map[string]string - handlerMarks []string - chiRoutePattern *string - } - - var res resultStruct - h := func(optMark ...string) func(resp http.ResponseWriter, req *http.Request) { - mark := util.OptionalArg(optMark, "") - return func(resp http.ResponseWriter, req *http.Request) { - chiCtx := chi.RouteContext(req.Context()) - res.method = req.Method - res.pathParams = chiURLParamsToMap(chiCtx) - res.chiRoutePattern = new(chiCtx.RoutePattern()) - if mark != "" { - res.handlerMarks = append(res.handlerMarks, mark) - } - } - } - - stopMark := func(optMark ...string) func(resp http.ResponseWriter, req *http.Request) { - mark := util.OptionalArg(optMark, "") - return func(resp http.ResponseWriter, req *http.Request) { - if stop := req.FormValue("stop"); stop != "" && (mark == "" || mark == stop) { - h(stop)(resp, req) - resp.WriteHeader(http.StatusOK) - } else if mark != "" { - res.handlerMarks = append(res.handlerMarks, mark) - } - } - } + type resultStruct = testResult + resRecorder := &testRecorder{} + h := resRecorder.handle + stopMark := resRecorder.stop r := NewRouter() r.NotFound(h("not-found:/")) @@ -123,15 +156,7 @@ func TestRouter(t *testing.T) { testRoute := func(t *testing.T, methodPath string, expected resultStruct) { t.Run(methodPath, func(t *testing.T) { - res = resultStruct{} - methodPathFields := strings.Fields(methodPath) - req, err := http.NewRequest(methodPathFields[0], methodPathFields[1], nil) - assert.NoError(t, err) - r.ServeHTTP(recorder, req) - if expected.chiRoutePattern == nil { - res.chiRoutePattern = nil - } - assert.Equal(t, expected, res) + resRecorder.test(t, r, methodPath, expected) }) } @@ -273,3 +298,39 @@ func TestRouteNormalizePath(t *testing.T) { testPath("/v2/", paths{EscapedPath: "/v2", RawPath: "/v2", Path: "/v2"}) testPath("/v2/%2f", paths{EscapedPath: "/v2/%2f", RawPath: "/v2/%2f", Path: "/v2//"}) } + +func TestPreMiddlewareProvider(t *testing.T) { + resRecorder := &testRecorder{} + h := resRecorder.handle + p := resRecorder.provider + + root := NewRouter() + root.BeforeRouting(h("before-root")) + root.AfterRouting(h("root")) + root.Get("/a/1", h("mid"), PreMiddlewareProvider(p("pre-root")), h("end1")) + + sub := NewRouter() + sub.BeforeRouting(h("before-sub")) + sub.AfterRouting(h("sub")) + sub.Get("/2", h("mid"), PreMiddlewareProvider(p("pre-sub")), h("end2")) + sub.NotFound(h("not-found")) + + root.Mount("/a", sub) + + resRecorder.test(t, root, "GET /a/1", testResult{ + method: "GET", + handlerMarks: []string{"before-root", "pre-root", "root", "mid", "end1"}, + }) + resRecorder.test(t, root, "GET /a/2", testResult{ + method: "GET", + handlerMarks: []string{"before-root", "root", "before-sub", "pre-sub", "sub", "mid", "end2"}, + }) + resRecorder.test(t, root, "GET /no-such", testResult{ + method: "GET", + handlerMarks: []string{"before-root"}, + }) + resRecorder.test(t, root, "GET /a/no-such", testResult{ + method: "GET", + handlerMarks: []string{"before-root", "root", "before-sub", "sub", "not-found"}, + }) +} diff --git a/modules/web/routing/context.go b/modules/web/routing/context.go index d3eb98f83d..838abea158 100644 --- a/modules/web/routing/context.go +++ b/modules/web/routing/context.go @@ -44,7 +44,7 @@ func MarkLongPolling(resp http.ResponseWriter, req *http.Request) { } // UpdatePanicError updates a context's error info, a panic may be recovered by other middlewares, but we still need to know that. -func UpdatePanicError(ctx context.Context, err any) { +func UpdatePanicError(ctx context.Context, err error) { record, ok := ctx.Value(contextKey).(*requestRecord) if !ok { return diff --git a/modules/web/routing/logger_manager.go b/modules/web/routing/logger_manager.go index aa25ec3a27..2f767c3d66 100644 --- a/modules/web/routing/logger_manager.go +++ b/modules/web/routing/logger_manager.go @@ -5,11 +5,13 @@ package routing import ( "context" + "fmt" "net/http" "sync" "time" "code.gitea.io/gitea/modules/graceful" + "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/process" ) @@ -99,7 +101,7 @@ func (manager *requestRecordsManager) handler(next http.Handler) http.Handler { localPanicErr := recover() if localPanicErr != nil { record.lock.Lock() - record.panicError = localPanicErr + record.panicError = fmt.Errorf("%v\n%s", localPanicErr, log.Stack(2)) record.lock.Unlock() } diff --git a/modules/web/routing/requestrecord.go b/modules/web/routing/requestrecord.go index cc61fc4d34..888c3e5c2f 100644 --- a/modules/web/routing/requestrecord.go +++ b/modules/web/routing/requestrecord.go @@ -24,5 +24,5 @@ type requestRecord struct { // mutable fields isLongPolling bool funcInfo *FuncInfo - panicError any + panicError error } diff --git a/routers/api/actions/artifacts.go b/routers/api/actions/artifacts.go index 6bef6c03c0..d7e7203a85 100644 --- a/routers/api/actions/artifacts.go +++ b/routers/api/actions/artifacts.go @@ -103,7 +103,7 @@ func init() { func ArtifactsRoutes(prefix string) *web.Router { m := web.NewRouter() - m.Use(ArtifactContexter()) + m.AfterRouting(ArtifactContexter()) r := artifactRoutes{ prefix: prefix, diff --git a/routers/api/packages/api.go b/routers/api/packages/api.go index 71fee23c92..7f81242db4 100644 --- a/routers/api/packages/api.go +++ b/routers/api/packages/api.go @@ -94,7 +94,7 @@ func verifyAuth(r *web.Router, authMethods []auth.Method) { } authGroup := auth.NewGroup(authMethods...) - r.Use(func(ctx *context.Context) { + r.AfterRouting(func(ctx *context.Context) { var err error ctx.Doer, err = authGroup.Verify(ctx.Req, ctx.Resp, ctx, ctx.Session) if err != nil { @@ -111,7 +111,7 @@ func verifyAuth(r *web.Router, authMethods []auth.Method) { func CommonRoutes() *web.Router { r := web.NewRouter() - r.Use(context.PackageContexter()) + r.AfterRouting(context.PackageContexter()) verifyAuth(r, []auth.Method{ &auth.OAuth2{}, @@ -533,7 +533,7 @@ func CommonRoutes() *web.Router { func ContainerRoutes() *web.Router { r := web.NewRouter() - r.Use(context.PackageContexter()) + r.AfterRouting(context.PackageContexter()) verifyAuth(r, []auth.Method{ &auth.Basic{}, diff --git a/routers/api/v1/api.go b/routers/api/v1/api.go index 767e5533fd..e95e597932 100644 --- a/routers/api/v1/api.go +++ b/routers/api/v1/api.go @@ -77,7 +77,6 @@ import ( repo_model "code.gitea.io/gitea/models/repo" "code.gitea.io/gitea/models/unit" user_model "code.gitea.io/gitea/models/user" - "code.gitea.io/gitea/modules/graceful" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" api "code.gitea.io/gitea/modules/structs" @@ -756,13 +755,9 @@ func buildAuthGroup() *auth.Group { &auth.Basic{}, // FIXME: this should be removed once we don't allow basic auth in API ) if setting.Service.EnableReverseProxyAuthAPI { - group.Add(&auth.ReverseProxy{}) + group.Add(&auth.ReverseProxy{}) // TODO: does it still make sense to support reverse proxy auth in API? } - - if setting.IsWindows && auth_model.IsSSPIEnabled(graceful.GetManager().ShutdownContext()) { - group.Add(&auth.SSPI{}) // it MUST be the last, see the comment of SSPI - } - + // others: API doesn't support SSPI auth because the caller should use token return group } @@ -872,9 +867,9 @@ func checkDeprecatedAuthMethods(ctx *context.APIContext) { func Routes() *web.Router { m := web.NewRouter() - m.Use(securityHeaders()) + m.BeforeRouting(securityHeaders()) if setting.CORSConfig.Enabled { - m.Use(cors.Handler(cors.Options{ + m.BeforeRouting(cors.Handler(cors.Options{ AllowedOrigins: setting.CORSConfig.AllowDomain, AllowedMethods: setting.CORSConfig.Methods, AllowCredentials: setting.CORSConfig.AllowCredentials, @@ -882,14 +877,14 @@ func Routes() *web.Router { MaxAge: int(setting.CORSConfig.MaxAge.Seconds()), })) } - m.Use(context.APIContexter()) - m.Use(checkDeprecatedAuthMethods) + m.AfterRouting(context.APIContexter()) + m.AfterRouting(checkDeprecatedAuthMethods) // Get user from session if logged in. - m.Use(apiAuth(buildAuthGroup())) + m.AfterRouting(apiAuth(buildAuthGroup())) - m.Use(verifyAuthWithOptions(&common.VerifyOptions{ + m.AfterRouting(verifyAuthWithOptions(&common.VerifyOptions{ SignInRequired: setting.Service.RequireSignInViewStrict, })) diff --git a/routers/common/errpage.go b/routers/common/errpage.go index 2406cf443f..07760bcd18 100644 --- a/routers/common/errpage.go +++ b/routers/common/errpage.go @@ -53,18 +53,18 @@ func renderServerErrorPage(w http.ResponseWriter, req *http.Request, respCode in _, _ = io.Copy(w, outBuf) } -// RenderPanicErrorPage renders a 500 page, and it never panics -func RenderPanicErrorPage(w http.ResponseWriter, req *http.Request, err any) { - combinedErr := fmt.Sprintf("%v\n%s", err, log.Stack(2)) - log.Error("PANIC: %s", combinedErr) +// renderPanicErrorPage renders a 500 page with the recovered panic value, it handles the stack trace, and it never panics +func renderPanicErrorPage(w http.ResponseWriter, req *http.Request, recovered any) { + combinedErr := fmt.Errorf("%v\n%s", recovered, log.Stack(2)) + log.Error("PANIC: %v", combinedErr) defer func() { if err := recover(); err != nil { - log.Error("Panic occurs again when rendering error page: %v. Stack:\n%s", err, log.Stack(2)) + log.Error("Panic occurs again when rendering error page: %v. Stack:\n%s", combinedErr, log.Stack(2)) } }() - routing.UpdatePanicError(req.Context(), err) + routing.UpdatePanicError(req.Context(), combinedErr) plainMsg := "Internal Server Error" ctxData := middleware.GetContextData(req.Context()) @@ -72,7 +72,7 @@ func RenderPanicErrorPage(w http.ResponseWriter, req *http.Request, err any) { // Otherwise, the 500-page may cause new panics, eg: cache.GetContextWithData, it makes the developer&users couldn't find the original panic. user, _ := ctxData[middleware.ContextDataKeySignedUser].(*user_model.User) if !setting.IsProd || (user != nil && user.IsAdmin) { - plainMsg = "PANIC: " + combinedErr + plainMsg = "PANIC: " + combinedErr.Error() ctxData["ErrorMsg"] = plainMsg } renderServerErrorPage(w, req, http.StatusInternalServerError, tplStatus500, ctxData, plainMsg) diff --git a/routers/common/errpage_test.go b/routers/common/errpage_test.go index c50d45c296..b81803c015 100644 --- a/routers/common/errpage_test.go +++ b/routers/common/errpage_test.go @@ -22,7 +22,7 @@ func TestRenderPanicErrorPage(t *testing.T) { w := httptest.NewRecorder() req := &http.Request{URL: &url.URL{}, Header: http.Header{"Accept": []string{"text/html"}}} req = req.WithContext(reqctx.NewRequestContextForTest(t.Context())) - RenderPanicErrorPage(w, req, errors.New("fake panic error (for test only)")) + renderPanicErrorPage(w, req, errors.New("fake panic error (for test only)")) respContent := w.Body.String() assert.Contains(t, respContent, `class="page-content status-page-500"`) assert.Contains(t, respContent, ``) diff --git a/routers/common/middleware.go b/routers/common/middleware.go index 6bf430d361..9daffb04f1 100644 --- a/routers/common/middleware.go +++ b/routers/common/middleware.go @@ -5,13 +5,13 @@ package common import ( "fmt" - "log" "net/http" "strings" "code.gitea.io/gitea/modules/cache" "code.gitea.io/gitea/modules/gtprof" "code.gitea.io/gitea/modules/httplib" + "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/reqctx" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/web/routing" @@ -63,8 +63,8 @@ func RequestContextHandler() func(h http.Handler) http.Handler { }() defer func() { - if err := recover(); err != nil { - RenderPanicErrorPage(respWriter, req, err) // it should never panic + if recovered := recover(); recovered != nil { + renderPanicErrorPage(respWriter, req, recovered) // it should never panic, and it handles the stack trace internally } }() @@ -130,7 +130,7 @@ func MustInitSessioner() func(next http.Handler) http.Handler { Domain: setting.SessionConfig.Domain, }) if err != nil { - log.Fatalf("common.Sessioner failed: %v", err) + log.Fatal("common.Sessioner failed: %v", err) } return middleware } diff --git a/routers/init.go b/routers/init.go index 8874236a60..2ed7a57e5c 100644 --- a/routers/init.go +++ b/routers/init.go @@ -180,8 +180,9 @@ func InitWebInstalled(ctx context.Context) { // NormalRoutes represents non install routes func NormalRoutes() *web.Router { r := web.NewRouter() - r.Use(common.ProtocolMiddlewares()...) - r.Use(common.MaintenanceModeHandler()) + r.BeforeRouting(common.ProtocolMiddlewares()...) + + r.AfterRouting(common.MaintenanceModeHandler()) r.Mount("/", web_routers.Routes()) r.Mount("/api/v1", apiv1.Routes()) diff --git a/routers/install/routes.go b/routers/install/routes.go index 0914c921c0..6fd511cbd1 100644 --- a/routers/install/routes.go +++ b/routers/install/routes.go @@ -20,11 +20,12 @@ import ( // Routes registers the installation routes func Routes() *web.Router { base := web.NewRouter() - base.Use(common.ProtocolMiddlewares()...) + base.BeforeRouting(common.ProtocolMiddlewares()...) + base.Methods("GET, HEAD", "/assets/*", public.FileHandlerFunc()) r := web.NewRouter() - r.Use(common.MustInitSessioner(), installContexter()) + r.AfterRouting(common.MustInitSessioner(), installContexter()) r.Get("/", Install) // it must be on the root, because the "install.js" use the window.location to replace the "localhost" AppURL r.Post("/", web.Bind(forms.InstallForm{}), SubmitInstall) diff --git a/routers/private/internal.go b/routers/private/internal.go index 2d5436468b..d4918455f1 100644 --- a/routers/private/internal.go +++ b/routers/private/internal.go @@ -54,11 +54,11 @@ func bind[T any](_ T) any { // These APIs will be invoked by internal commands for example `gitea serv` and etc. func Routes() *web.Router { r := web.NewRouter() - r.Use(context.PrivateContexter()) - r.Use(authInternal) + r.AfterRouting(context.PrivateContexter()) + r.AfterRouting(authInternal) // Log the real ip address of the request from SSH is really helpful for diagnosing sometimes. // Since internal API will be sent only from Gitea sub commands and it's under control (checked by InternalToken), we can trust the headers. - r.Use(chi_middleware.RealIP) + r.AfterRouting(chi_middleware.RealIP) r.Get("/dummy", misc.DummyOK) r.Post("/ssh/authorized_keys", AuthorizedPublicKeyByContent) diff --git a/routers/web/githttp.go b/routers/web/githttp.go index 43d318c1a1..c738b68c8d 100644 --- a/routers/web/githttp.go +++ b/routers/web/githttp.go @@ -6,12 +6,9 @@ package web import ( "code.gitea.io/gitea/modules/web" "code.gitea.io/gitea/routers/web/repo" - "code.gitea.io/gitea/services/context" ) -func addOwnerRepoGitHTTPRouters(m *web.Router) { - // Some users want to use "web-based git client" to access Gitea's repositories, - // so the CORS handler and OPTIONS method are used. +func addOwnerRepoGitHTTPRouters(m *web.Router, middlewares ...any) { m.Group("/{username}/{reponame}", func() { m.Methods("POST,OPTIONS", "/git-upload-pack", repo.ServiceUploadPack) m.Methods("POST,OPTIONS", "/git-receive-pack", repo.ServiceReceivePack) @@ -25,5 +22,5 @@ func addOwnerRepoGitHTTPRouters(m *web.Router) { m.Methods("GET,OPTIONS", "/objects/{head:[0-9a-f]{2}}/{hash:[0-9a-f]{38,62}}", repo.GetLooseObject) m.Methods("GET,OPTIONS", "/objects/pack/pack-{file:[0-9a-f]{40,64}}.pack", repo.GetPackFile) m.Methods("GET,OPTIONS", "/objects/pack/pack-{file:[0-9a-f]{40,64}}.idx", repo.GetIdxFile) - }, repo.HTTPGitEnabledHandler, repo.CorsHandler(), optSignInFromAnyOrigin, context.UserAssignmentWeb()) + }, middlewares...) } diff --git a/routers/web/web.go b/routers/web/web.go index 182c6c595b..95a54b5244 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -15,6 +15,7 @@ import ( "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/metrics" "code.gitea.io/gitea/modules/public" + "code.gitea.io/gitea/modules/reqctx" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/storage" "code.gitea.io/gitea/modules/structs" @@ -90,32 +91,64 @@ func optionsCorsHandler() func(next http.Handler) http.Handler { } } -// The OAuth2 plugin is expected to be executed first, as it must ignore the user id stored -// in the session (if there is a user id stored in session other plugins might return the user -// object for that id). -// -// The Session plugin is expected to be executed second, in order to skip authentication -// for users that have already signed in. -func buildAuthGroup() *auth_service.Group { - group := auth_service.NewGroup() - group.Add(&auth_service.OAuth2{}) // FIXME: this should be removed and only applied in download and oauth related routers - group.Add(&auth_service.Basic{}) // FIXME: this should be removed and only applied in download and git/lfs routers - - if setting.Service.EnableReverseProxyAuth { - group.Add(&auth_service.ReverseProxy{}) // reverse-proxy should before Session, otherwise the header will be ignored if user has login - } - group.Add(&auth_service.Session{}) - - if setting.IsWindows && auth_model.IsSSPIEnabled(graceful.GetManager().ShutdownContext()) { - group.Add(&auth_service.SSPI{}) // it MUST be the last, see the comment of SSPI - } - - return group +type AuthMiddleware struct { + AllowOAuth2 web.PreMiddlewareProvider + AllowBasic web.PreMiddlewareProvider + MiddlewareHandler func(*context.Context) } -func webAuth(authMethod auth_service.Method) func(*context.Context) { - return func(ctx *context.Context) { - ar, err := common.AuthShared(ctx.Base, ctx.Session, authMethod) +func newWebAuthMiddleware() *AuthMiddleware { + type keyAllowOAuth2 struct{} + type keyAllowBasic struct{} + webAuth := &AuthMiddleware{} + + middlewareSetContextValue := func(key, val any) web.PreMiddlewareProvider { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + dataStore := reqctx.GetRequestDataStore(r.Context()) + dataStore.SetContextValue(key, val) + next.ServeHTTP(w, r) + }) + } + } + + webAuth.AllowBasic = middlewareSetContextValue(keyAllowBasic{}, true) + webAuth.AllowOAuth2 = middlewareSetContextValue(keyAllowOAuth2{}, true) + + enableSSPI := setting.IsWindows && auth_model.IsSSPIEnabled(graceful.GetManager().ShutdownContext()) + webAuth.MiddlewareHandler = func(ctx *context.Context) { + allowBasic := ctx.GetContextValue(keyAllowBasic{}) == true + allowOAuth2 := ctx.GetContextValue(keyAllowOAuth2{}) == true + + group := auth_service.NewGroup() + + // Most auth methods should ignore the user id stored in the session. + // If the auth succeeds, it must use the user id from the auth method to make sure the new login succeeds. + if allowOAuth2 { + group.Add(&auth_service.OAuth2{}) + } + if allowBasic { + group.Add(&auth_service.Basic{}) + } + + // Sessionless means the route's auth can be done without web ui, then it doesn't need to create a session + // For example: accessing git via http, access rss feeds, downloading attachments, etc + isSessionless := allowOAuth2 || allowBasic + + if setting.Service.EnableReverseProxyAuth { + // reverse-proxy should before Session, otherwise the header will be ignored if user has login + group.Add(&auth_service.ReverseProxy{CreateSession: !isSessionless}) + } + + // The Session plugin will skip authentication for users that have already signed in. + group.Add(&auth_service.Session{}) + + if enableSSPI { + // it MUST be the last, see the comment of SSPI + group.Add(&auth_service.SSPI{CreateSession: !isSessionless}) + } + + ar, err := common.AuthShared(ctx.Base, ctx.Session, group) if err != nil { log.Error("Failed to verify user: %v", err) ctx.HTTPError(http.StatusUnauthorized, "Failed to authenticate user") @@ -129,6 +162,7 @@ func webAuth(authMethod auth_service.Method) func(*context.Context) { _ = ctx.Session.Delete("uid") } } + return webAuth } // verifyAuthWithOptions checks authentication according to options @@ -223,6 +257,9 @@ const RouterMockPointBeforeWebRoutes = "before-web-routes" func Routes() *web.Router { routes := web.NewRouter() + // GetHead allows a HEAD request redirect to GET if HEAD method is not defined for that route + routes.BeforeRouting(chi_middleware.GetHead) + routes.Head("/", misc.DummyOK) // for health check - doesn't need to be passed through gzip handler routes.Methods("GET, HEAD, OPTIONS", "/assets/*", optionsCorsHandler(), public.FileHandlerFunc()) routes.Methods("GET, HEAD", "/avatars/*", avatarStorageHandler(setting.Avatar.Storage, "avatars", storage.Avatars)) @@ -260,10 +297,8 @@ func Routes() *web.Router { mid = append(mid, common.MustInitSessioner(), context.Contexter()) // Get user from session if logged in. - mid = append(mid, webAuth(buildAuthGroup())) - - // GetHead allows a HEAD request redirect to GET if HEAD method is not defined for that route - mid = append(mid, chi_middleware.GetHead) + webAuth := newWebAuthMiddleware() + mid = append(mid, webAuth.MiddlewareHandler) if setting.API.EnableSwagger { // Note: The route is here but no in API routes because it renders a web page @@ -272,10 +307,12 @@ func Routes() *web.Router { mid = append(mid, goGet) mid = append(mid, common.PageGlobalData) + mid = append(mid, common.BlockExpensive(), common.QoS(), web.RouterMockPoint(RouterMockPointBeforeWebRoutes)) webRoutes := web.NewRouter() - webRoutes.Use(mid...) - webRoutes.Group("", func() { registerWebRoutes(webRoutes) }, common.BlockExpensive(), common.QoS(), web.RouterMockPoint(RouterMockPointBeforeWebRoutes)) + webRoutes.AfterRouting(mid...) + registerWebRoutes(webRoutes, webAuth) + routes.Mount("", webRoutes) return routes } @@ -288,7 +325,7 @@ func Routes() *web.Router { var optSignInFromAnyOrigin = verifyAuthWithOptions(&common.VerifyOptions{DisableCrossOriginProtection: true}) // registerWebRoutes register routes -func registerWebRoutes(m *web.Router) { +func registerWebRoutes(m *web.Router, webAuth *AuthMiddleware) { // required to be signed in or signed out reqSignIn := verifyAuthWithOptions(&common.VerifyOptions{SignInRequired: true}) reqSignOut := verifyAuthWithOptions(&common.VerifyOptions{SignOutRequired: true}) @@ -565,7 +602,7 @@ func registerWebRoutes(m *web.Router) { m.Methods("POST, OPTIONS", "/access_token", web.Bind(forms.AccessTokenForm{}), auth.AccessTokenOAuth) m.Methods("GET, OPTIONS", "/keys", auth.OIDCKeys) m.Methods("POST, OPTIONS", "/introspect", web.Bind(forms.IntrospectTokenForm{}), auth.IntrospectOAuth) - }, optionsCorsHandler(), optSignInFromAnyOrigin) + }, optionsCorsHandler(), webAuth.AllowOAuth2, optSignInFromAnyOrigin) }, oauth2Enabled) m.Group("/user/settings", func() { @@ -816,8 +853,9 @@ func registerWebRoutes(m *web.Router) { // ***** END: Admin ***** m.Group("", func() { - m.Get("/{username}", user.UsernameSubRoute) - m.Methods("GET, OPTIONS", "/attachments/{uuid}", optionsCorsHandler(), repo.GetAttachment) + // it handles "username.rss" in the handler, so allow basic auth as other rss/atom routes + m.Get("/{username}", webAuth.AllowBasic, user.UsernameSubRoute) + m.Methods("GET, OPTIONS", "/attachments/{uuid}", optionsCorsHandler(), webAuth.AllowBasic, webAuth.AllowOAuth2, repo.GetAttachment) }, optSignIn) m.Post("/{username}", reqSignIn, context.UserAssignmentWeb(), user.ActionUserFollow) @@ -1188,7 +1226,7 @@ func registerWebRoutes(m *web.Router) { // end "/{username}/{reponame}/settings" // user/org home, including rss feeds like "/{username}/{reponame}.rss" - m.Get("/{username}/{reponame}", optSignIn, context.RepoAssignment, context.RepoRefByType(git.RefTypeBranch), repo.SetEditorconfigIfExists, repo.Home) + m.Get("/{username}/{reponame}", optSignIn, webAuth.AllowBasic, context.RepoAssignment, context.RepoRefByType(git.RefTypeBranch), repo.SetEditorconfigIfExists, repo.Home) m.Post("/{username}/{reponame}/markup", optSignIn, context.RepoAssignment, reqUnitsWithMarkdown, web.Bind(structs.MarkupOption{}), misc.Markup) @@ -1389,8 +1427,8 @@ func registerWebRoutes(m *web.Router) { m.Group("/{username}/{reponame}", func() { // repo tags m.Group("/tags", func() { m.Get("", context.RepoRefByDefaultBranch() /* for the "commits" tab */, repo.TagsList) - m.Get(".rss", feedEnabled, repo.TagsListFeedRSS) - m.Get(".atom", feedEnabled, repo.TagsListFeedAtom) + m.Get(".rss", webAuth.AllowBasic, feedEnabled, repo.TagsListFeedRSS) + m.Get(".atom", webAuth.AllowBasic, feedEnabled, repo.TagsListFeedAtom) m.Get("/list", repo.GetTagList) }, ctxDataSet("EnableFeed", setting.Other.EnableFeed)) m.Post("/tags/delete", reqSignIn, reqRepoCodeWriter, context.RepoMustNotBeArchived(), repo.DeleteTag) @@ -1400,13 +1438,13 @@ func registerWebRoutes(m *web.Router) { m.Group("/{username}/{reponame}", func() { // repo releases m.Group("/releases", func() { m.Get("", repo.Releases) - m.Get(".rss", feedEnabled, repo.ReleasesFeedRSS) - m.Get(".atom", feedEnabled, repo.ReleasesFeedAtom) + m.Get(".rss", webAuth.AllowBasic, feedEnabled, repo.ReleasesFeedRSS) + m.Get(".atom", webAuth.AllowBasic, feedEnabled, repo.ReleasesFeedAtom) m.Get("/tag/*", repo.SingleRelease) m.Get("/latest", repo.LatestRelease) }, ctxDataSet("EnableFeed", setting.Other.EnableFeed)) - m.Get("/releases/attachments/{uuid}", repo.GetAttachment) - m.Get("/releases/download/{vTag}/{fileName}", repo.RedirectDownload) + m.Get("/releases/attachments/{uuid}", webAuth.AllowBasic, webAuth.AllowOAuth2, repo.GetAttachment) + m.Get("/releases/download/{vTag}/{fileName}", webAuth.AllowBasic, webAuth.AllowOAuth2, repo.RedirectDownload) m.Group("/releases", func() { m.Get("/new", repo.NewRelease) m.Post("/new", web.Bind(forms.NewReleaseForm{}), repo.NewReleasePost) @@ -1423,7 +1461,7 @@ func registerWebRoutes(m *web.Router) { // end "/{username}/{reponame}": repo releases m.Group("/{username}/{reponame}", func() { // to maintain compatibility with old attachments - m.Get("/attachments/{uuid}", repo.GetAttachment) + m.Get("/attachments/{uuid}", webAuth.AllowBasic, webAuth.AllowOAuth2, repo.GetAttachment) }, optSignIn, context.RepoAssignment) // end "/{username}/{reponame}": compatibility with old attachments @@ -1492,7 +1530,7 @@ func registerWebRoutes(m *web.Router) { m.Post("/rerun", reqRepoActionsWriter, actions.Rerun) }) m.Group("/workflows/{workflow_name}", func() { - m.Get("/badge.svg", actions.GetWorkflowBadge) + m.Get("/badge.svg", webAuth.AllowBasic, webAuth.AllowOAuth2, actions.GetWorkflowBadge) }) }, optSignIn, context.RepoAssignment, repo.MustBeNotEmpty, reqRepoActionsReader, actions.MustEnableActions) // end "/{username}/{reponame}/actions" @@ -1578,7 +1616,7 @@ func registerWebRoutes(m *web.Router) { m.Group("/archive", func() { m.Get("/*", repo.Download) m.Post("/*", repo.InitiateDownload) - }, repo.MustBeNotEmpty, dlSourceEnabled) + }, webAuth.AllowBasic, webAuth.AllowOAuth2, repo.MustBeNotEmpty, dlSourceEnabled) m.Group("/branches", func() { m.Get("/list", repo.GetBranchesList) @@ -1591,7 +1629,7 @@ func registerWebRoutes(m *web.Router) { m.Get("/tag/*", context.RepoRefByType(git.RefTypeTag), repo.SingleDownloadOrLFS) m.Get("/commit/*", context.RepoRefByType(git.RefTypeCommit), repo.SingleDownloadOrLFS) m.Get("/*", context.RepoRefByType(""), repo.SingleDownloadOrLFS) // "/*" route is deprecated, and kept for backward compatibility - }, repo.MustBeNotEmpty) + }, webAuth.AllowBasic, webAuth.AllowOAuth2, repo.MustBeNotEmpty) m.Group("/raw", func() { m.Get("/blob/{sha}", repo.DownloadByID) @@ -1599,7 +1637,7 @@ func registerWebRoutes(m *web.Router) { m.Get("/tag/*", context.RepoRefByType(git.RefTypeTag), repo.SingleDownload) m.Get("/commit/*", context.RepoRefByType(git.RefTypeCommit), repo.SingleDownload) m.Get("/*", context.RepoRefByType(""), repo.SingleDownload) // "/*" route is deprecated, and kept for backward compatibility - }, repo.MustBeNotEmpty) + }, webAuth.AllowBasic, webAuth.AllowOAuth2, repo.MustBeNotEmpty) m.Group("/render", func() { m.Get("/branch/*", context.RepoRefByType(git.RefTypeBranch), repo.RenderFile) @@ -1632,8 +1670,8 @@ func registerWebRoutes(m *web.Router) { m.Get("/cherry-pick/{sha:([a-f0-9]{7,64})$}", repo.SetEditorconfigIfExists, context.RepoRefByDefaultBranch(), repo.CherryPick) }, repo.MustBeNotEmpty) - m.Get("/rss/branch/*", context.RepoRefByType(git.RefTypeBranch), feedEnabled, feed.RenderBranchFeedRSS) - m.Get("/atom/branch/*", context.RepoRefByType(git.RefTypeBranch), feedEnabled, feed.RenderBranchFeedAtom) + m.Get("/rss/branch/*", context.RepoRefByType(git.RefTypeBranch), webAuth.AllowBasic, feedEnabled, feed.RenderBranchFeedRSS) + m.Get("/atom/branch/*", context.RepoRefByType(git.RefTypeBranch), webAuth.AllowBasic, feedEnabled, feed.RenderBranchFeedAtom) m.Group("/src", func() { m.Get("", func(ctx *context.Context) { ctx.Redirect(ctx.Repo.RepoLink) }) // there is no "{owner}/{repo}/src" page, so redirect to "{owner}/{repo}" to avoid 404 @@ -1660,9 +1698,14 @@ func registerWebRoutes(m *web.Router) { m.Post("/action/{action:accept_transfer|reject_transfer}", reqSignIn, repo.ActionTransfer) }, optSignIn, context.RepoAssignment) - common.AddOwnerRepoGitLFSRoutes(m, lfsServerEnabled, repo.CorsHandler(), optSignInFromAnyOrigin) // "/{username}/{reponame}/{lfs-paths}": git-lfs support, see also addOwnerRepoGitHTTPRouters + // git lfs uses its own jwt key, and it handles the token & auth by itself, it conflicts with the general "OAuth2" auth method + // pattern: "/{username}/{reponame}/{lfs-paths}": git-lfs support, see also addOwnerRepoGitHTTPRouters + common.AddOwnerRepoGitLFSRoutes(m, lfsServerEnabled, webAuth.AllowBasic, repo.CorsHandler(), optSignInFromAnyOrigin) - addOwnerRepoGitHTTPRouters(m) // "/{username}/{reponame}/{git-paths}": git http support + // Some users want to use "web-based git client" to access Gitea's repositories, + // so the CORS handler and OPTIONS method are used. + // pattern: "/{username}/{reponame}/{git-paths}": git http support + addOwnerRepoGitHTTPRouters(m, repo.HTTPGitEnabledHandler, webAuth.AllowBasic, webAuth.AllowOAuth2, repo.CorsHandler(), optSignInFromAnyOrigin, context.UserAssignmentWeb()) m.Group("/notifications", func() { m.Get("", user.Notifications) diff --git a/services/auth/auth.go b/services/auth/auth.go index 90e2115bc5..ebe8277f77 100644 --- a/services/auth/auth.go +++ b/services/auth/auth.go @@ -8,38 +8,16 @@ import ( "errors" "fmt" "net/http" - "regexp" - "strings" - "sync" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/auth/webauthn" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/optional" "code.gitea.io/gitea/modules/session" - "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/web/middleware" user_service "code.gitea.io/gitea/services/user" ) -type globalVarsStruct struct { - gitRawOrAttachPathRe *regexp.Regexp - lfsPathRe *regexp.Regexp - archivePathRe *regexp.Regexp - feedPathRe *regexp.Regexp - feedRefPathRe *regexp.Regexp -} - -var globalVars = sync.OnceValue(func() *globalVarsStruct { - return &globalVarsStruct{ - gitRawOrAttachPathRe: regexp.MustCompile(`^/[-.\w]+/[-.\w]+/(?:(?:git-(?:(?:upload)|(?:receive))-pack$)|(?:info/refs$)|(?:HEAD$)|(?:objects/)|(?:raw/)|(?:releases/download/)|(?:attachments/))`), - lfsPathRe: regexp.MustCompile(`^/[-.\w]+/[-.\w]+/info/lfs/`), - archivePathRe: regexp.MustCompile(`^/[-.\w]+/[-.\w]+/archive/`), - feedPathRe: regexp.MustCompile(`^/[-.\w]+(/[-.\w]+)?\.(rss|atom)$`), // "/owner.rss" or "/owner/repo.atom" - feedRefPathRe: regexp.MustCompile(`^/[-.\w]+/[-.\w]+/(rss|atom)/`), // "/owner/repo/rss/branch/..." - } -}) - type ErrUserAuthMessage string func (e ErrUserAuthMessage) Error() string { @@ -60,66 +38,6 @@ func Init() { webauthn.Init() } -type authPathDetector struct { - req *http.Request - vars *globalVarsStruct -} - -func newAuthPathDetector(req *http.Request) *authPathDetector { - return &authPathDetector{req: req, vars: globalVars()} -} - -// isAPIPath returns true if the specified URL is an API path -func (a *authPathDetector) isAPIPath() bool { - return strings.HasPrefix(a.req.URL.Path, "/api/") -} - -// isAttachmentDownload check if request is a file download (GET) with URL to an attachment -func (a *authPathDetector) isAttachmentDownload() bool { - return strings.HasPrefix(a.req.URL.Path, "/attachments/") && a.req.Method == http.MethodGet -} - -func (a *authPathDetector) isFeedRequest(req *http.Request) bool { - if !setting.Other.EnableFeed { - return false - } - if req.Method != http.MethodGet { - return false - } - return a.vars.feedPathRe.MatchString(req.URL.Path) || a.vars.feedRefPathRe.MatchString(req.URL.Path) -} - -// isContainerPath checks if the request targets the container endpoint -func (a *authPathDetector) isContainerPath() bool { - return strings.HasPrefix(a.req.URL.Path, "/v2/") -} - -func (a *authPathDetector) isGitRawOrAttachPath() bool { - return a.vars.gitRawOrAttachPathRe.MatchString(a.req.URL.Path) -} - -func (a *authPathDetector) isGitRawOrAttachOrLFSPath() bool { - if a.isGitRawOrAttachPath() { - return true - } - if setting.LFS.StartServer { - return a.vars.lfsPathRe.MatchString(a.req.URL.Path) - } - return false -} - -func (a *authPathDetector) isArchivePath() bool { - return a.vars.archivePathRe.MatchString(a.req.URL.Path) -} - -func (a *authPathDetector) isAuthenticatedTokenRequest() bool { - switch a.req.URL.Path { - case "/login/oauth/userinfo", "/login/oauth/introspect": - return true - } - return false -} - // handleSignIn clears existing session variables and stores new ones for the specified user object func handleSignIn(resp http.ResponseWriter, req *http.Request, sess SessionStore, user *user_model.User) { // We need to regenerate the session... diff --git a/services/auth/auth_test.go b/services/auth/auth_test.go deleted file mode 100644 index c45f312c90..0000000000 --- a/services/auth/auth_test.go +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2014 The Gogs Authors. All rights reserved. -// Copyright 2019 The Gitea Authors. All rights reserved. -// SPDX-License-Identifier: MIT - -package auth - -import ( - "net/http" - "testing" - - "code.gitea.io/gitea/modules/setting" - "code.gitea.io/gitea/modules/test" - - "github.com/stretchr/testify/assert" -) - -func Test_isGitRawOrLFSPath(t *testing.T) { - tests := []struct { - path string - - want bool - }{ - { - "/owner/repo/git-upload-pack", - true, - }, - { - "/owner/repo/git-receive-pack", - true, - }, - { - "/owner/repo/info/refs", - true, - }, - { - "/owner/repo/HEAD", - true, - }, - { - "/owner/repo/objects/info/alternates", - true, - }, - { - "/owner/repo/objects/info/http-alternates", - true, - }, - { - "/owner/repo/objects/info/packs", - true, - }, - { - "/owner/repo/objects/info/blahahsdhsdkla", - true, - }, - { - "/owner/repo/objects/01/23456789abcdef0123456789abcdef01234567", - true, - }, - { - "/owner/repo/objects/pack/pack-123456789012345678921234567893124567894.pack", - true, - }, - { - "/owner/repo/objects/pack/pack-0123456789abcdef0123456789abcdef0123456.idx", - true, - }, - { - "/owner/repo/raw/branch/foo/fanaso", - true, - }, - { - "/owner/repo/stars", - false, - }, - { - "/notowner", - false, - }, - { - "/owner/repo", - false, - }, - { - "/owner/repo/commit/123456789012345678921234567893124567894", - false, - }, - { - "/owner/repo/releases/download/tag/repo.tar.gz", - true, - }, - { - "/owner/repo/attachments/6d92a9ee-5d8b-4993-97c9-6181bdaa8955", - true, - }, - } - - defer test.MockVariableValue(&setting.LFS.StartServer)() - for _, tt := range tests { - t.Run(tt.path, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, "http://localhost"+tt.path, nil) - setting.LFS.StartServer = false - assert.Equal(t, tt.want, newAuthPathDetector(req).isGitRawOrAttachOrLFSPath()) - - setting.LFS.StartServer = true - assert.Equal(t, tt.want, newAuthPathDetector(req).isGitRawOrAttachOrLFSPath()) - }) - } - - lfsTests := []string{ - "/owner/repo/info/lfs/", - "/owner/repo/info/lfs/objects/batch", - "/owner/repo/info/lfs/objects/oid/filename", - "/owner/repo/info/lfs/objects/oid", - "/owner/repo/info/lfs/objects", - "/owner/repo/info/lfs/verify", - "/owner/repo/info/lfs/locks", - "/owner/repo/info/lfs/locks/verify", - "/owner/repo/info/lfs/locks/123/unlock", - } - for _, tt := range lfsTests { - t.Run(tt, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodPost, tt, nil) - setting.LFS.StartServer = false - got := newAuthPathDetector(req).isGitRawOrAttachOrLFSPath() - assert.Equalf(t, setting.LFS.StartServer, got, "isGitOrLFSPath(%q) = %v, want %v, %v", tt, got, setting.LFS.StartServer, globalVars().gitRawOrAttachPathRe.MatchString(tt)) - - setting.LFS.StartServer = true - got = newAuthPathDetector(req).isGitRawOrAttachOrLFSPath() - assert.Equalf(t, setting.LFS.StartServer, got, "isGitOrLFSPath(%q) = %v, want %v", tt, got, setting.LFS.StartServer) - }) - } -} - -func Test_isFeedRequest(t *testing.T) { - tests := []struct { - want bool - path string - }{ - {true, "/user.rss"}, - {true, "/user/repo.atom"}, - {false, "/user/repo"}, - {false, "/use/repo/file.rss"}, - - {true, "/org/repo/rss/branch/xxx"}, - {true, "/org/repo/atom/tag/xxx"}, - {false, "/org/repo/branch/main/rss/any"}, - {false, "/org/atom/any"}, - } - for _, tt := range tests { - t.Run(tt.path, func(t *testing.T) { - req, _ := http.NewRequest(http.MethodGet, "http://localhost"+tt.path, nil) - assert.Equal(t, tt.want, newAuthPathDetector(req).isFeedRequest(req)) - }) - } -} diff --git a/services/auth/basic.go b/services/auth/basic.go index 3161d7f33d..dda6451c36 100644 --- a/services/auth/basic.go +++ b/services/auth/basic.go @@ -41,13 +41,6 @@ func (b *Basic) Name() string { } func (b *Basic) parseAuthBasic(req *http.Request) (ret struct{ authToken, uname, passwd string }) { - // Basic authentication should only fire on API, Feed, Download, Archives or on Git or LFSPaths - // Not all feed (rss/atom) clients feature the ability to add cookies or headers, so we need to allow basic auth for feeds - detector := newAuthPathDetector(req) - if !detector.isAPIPath() && !detector.isFeedRequest(req) && !detector.isContainerPath() && !detector.isAttachmentDownload() && !detector.isArchivePath() && !detector.isGitRawOrAttachOrLFSPath() { - return ret - } - authHeader := req.Header.Get("Authorization") if authHeader == "" { return ret diff --git a/services/auth/oauth2.go b/services/auth/oauth2.go index 86903b0ce1..877237e96e 100644 --- a/services/auth/oauth2.go +++ b/services/auth/oauth2.go @@ -152,13 +152,6 @@ func (o *OAuth2) userFromToken(ctx context.Context, tokenSHA string, store DataS // If verification is successful returns an existing user object. // Returns nil if verification fails. func (o *OAuth2) Verify(req *http.Request, w http.ResponseWriter, store DataStore, sess SessionStore) (*user_model.User, error) { - // These paths are not API paths, but we still want to check for tokens because they maybe in the API returned URLs - detector := newAuthPathDetector(req) - if !detector.isAPIPath() && !detector.isAttachmentDownload() && !detector.isAuthenticatedTokenRequest() && - !detector.isGitRawOrAttachPath() && !detector.isArchivePath() { - return nil, nil //nolint:nilnil // the auth method is not applicable - } - token, ok := parseToken(req) if !ok { return nil, nil //nolint:nilnil // the auth method is not applicable diff --git a/services/auth/reverseproxy.go b/services/auth/reverseproxy.go index 064b263a67..e9b08d75f2 100644 --- a/services/auth/reverseproxy.go +++ b/services/auth/reverseproxy.go @@ -29,7 +29,9 @@ const ReverseProxyMethodName = "reverse_proxy" // On successful authentication the proxy is expected to populate the username in the // "setting.ReverseProxyAuthUser" header. Optionally it can also populate the email of the // user in the "setting.ReverseProxyAuthEmail" header. -type ReverseProxy struct{} +type ReverseProxy struct { + CreateSession bool +} // getUserName extracts the username from the "setting.ReverseProxyAuthUser" header func (r *ReverseProxy) getUserName(req *http.Request) string { @@ -115,9 +117,7 @@ func (r *ReverseProxy) Verify(req *http.Request, w http.ResponseWriter, store Da } } - // Make sure requests to API paths, attachment downloads, git and LFS do not create a new session - detector := newAuthPathDetector(req) - if !detector.isAPIPath() && !detector.isAttachmentDownload() && !detector.isGitRawOrAttachOrLFSPath() { + if r.CreateSession { if sess != nil && (sess.Get("uid") == nil || sess.Get("uid").(int64) != user.ID) { handleSignIn(w, req, sess, user) } diff --git a/services/auth/sspi.go b/services/auth/sspi.go index 6450753935..c21978f55a 100644 --- a/services/auth/sspi.go +++ b/services/auth/sspi.go @@ -46,7 +46,9 @@ var ( // The SSPI plugin is expected to be executed last, as it returns 401 status code if negotiation // fails (or if negotiation should continue), which would prevent other authentication methods // to execute at all. -type SSPI struct{} +type SSPI struct { + CreateSession bool +} // Name represents the name of auth method func (s *SSPI) Name() string { @@ -118,9 +120,7 @@ func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore, } } - // Make sure requests to API paths and PWA resources do not create a new session - detector := newAuthPathDetector(req) - if !detector.isAPIPath() && !detector.isAttachmentDownload() { + if s.CreateSession { handleSignIn(w, req, sess, user) } @@ -147,18 +147,9 @@ func (s *SSPI) getConfig(ctx context.Context) (*sspi.Source, error) { } func (s *SSPI) shouldAuthenticate(req *http.Request) (shouldAuth bool) { - shouldAuth = false - path := strings.TrimSuffix(req.URL.Path, "/") - if path == "/user/login" { - if req.FormValue("user_name") != "" && req.FormValue("password") != "" { - shouldAuth = false - } else if req.FormValue("auth_with_sspi") == "1" { - shouldAuth = true - } - } else { - detector := newAuthPathDetector(req) - shouldAuth = detector.isAPIPath() || detector.isAttachmentDownload() - } + // SSPI is only applicable for login requests with "auth_with_sspi" form value set to "1" + // See the template code with "auth_with_sspi" + shouldAuth = req.URL.Path == "/user/login" && req.FormValue("auth_with_sspi") == "1" return shouldAuth } diff --git a/tests/integration/migrate_test.go b/tests/integration/migrate_test.go index 5521c786d9..8c8f053ede 100644 --- a/tests/integration/migrate_test.go +++ b/tests/integration/migrate_test.go @@ -22,6 +22,7 @@ import ( "code.gitea.io/gitea/models/unittest" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/gitrepo" + "code.gitea.io/gitea/modules/httplib" "code.gitea.io/gitea/modules/setting" "code.gitea.io/gitea/modules/structs" "code.gitea.io/gitea/services/migrations" @@ -126,7 +127,7 @@ func Test_MigrateFromGiteaToGitea(t *testing.T) { session := loginUser(t, owner.Name) token := getTokenForLoggedInUser(t, session, auth_model.AccessTokenScopeAll) - resp, err := http.Get("https://gitea.com/gitea") + resp, err := httplib.NewRequest("https://gitea.com/gitea/test_repo.git", "GET").SetReadWriteTimeout(5 * time.Second).Response() if err != nil || resp.StatusCode != http.StatusOK { if resp != nil { resp.Body.Close()