Refactor auth middleware (#36848)

Principles: let the caller decide what it needs, but not let the
framework (middleware) guess what it should do.

Then a lot of hacky code can be removed. And some FIXMEs can be fixed.

This PR introduces a new kind of middleware: "PreMiddleware", it will be
executed before all other middlewares on the same routing level, then a
route can declare its options for other middlewares.

By the way, allow the workflow badge to be accessed by Basic or OAuth2
auth.

Fixes: https://github.com/go-gitea/gitea/pull/36830
Fixes: https://github.com/go-gitea/gitea/issues/36859
This commit is contained in:
wxiaoguang
2026-03-08 17:59:46 +08:00
committed by GitHub
parent a0996cb229
commit 3f1ef703d5
25 changed files with 338 additions and 444 deletions

View File

@@ -70,7 +70,8 @@ func preCheckHandler(fn reflect.Value, argsIn []reflect.Value) {
func prepareHandleArgsIn(resp http.ResponseWriter, req *http.Request, fn reflect.Value, fnInfo *routing.FuncInfo) []reflect.Value {
defer func() {
if err := recover(); err != nil {
if recovered := recover(); recovered != nil {
err := fmt.Errorf("%v\n%s", recovered, log.Stack(2))
log.Error("unable to prepare handler arguments for %s: %v", fnInfo.String(), err)
panic(err)
}
@@ -117,7 +118,17 @@ func hasResponseBeenWritten(argsIn []reflect.Value) bool {
return false
}
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) func(next http.Handler) http.Handler {
type middlewareProvider = func(next http.Handler) http.Handler
func executeMiddlewaresHandler(w http.ResponseWriter, r *http.Request, middlewares []middlewareProvider, endpoint http.HandlerFunc) {
handler := endpoint
for i := len(middlewares) - 1; i >= 0; i-- {
handler = middlewares[i](handler).ServeHTTP
}
handler(w, r)
}
func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo *routing.FuncInfo) middlewareProvider {
return func(next http.Handler) http.Handler {
h := hp(next) // this handle could be dynamically generated, so we can't use it for debug info
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
@@ -129,14 +140,14 @@ func wrapHandlerProvider[T http.Handler](hp func(next http.Handler) T, funcInfo
// toHandlerProvider converts a handler to a handler provider
// A handler provider is a function that takes a "next" http.Handler, it can be used as a middleware
func toHandlerProvider(handler any) func(next http.Handler) http.Handler {
func toHandlerProvider(handler any) middlewareProvider {
funcInfo := routing.GetFuncInfo(handler)
fn := reflect.ValueOf(handler)
if fn.Type().Kind() != reflect.Func {
panic(fmt.Sprintf("handler must be a function, but got %s", fn.Type()))
}
if hp, ok := handler.(func(next http.Handler) http.Handler); ok {
if hp, ok := handler.(middlewareProvider); ok {
return wrapHandlerProvider(hp, funcInfo)
} else if hp, ok := handler.(func(http.Handler) http.HandlerFunc); ok {
return wrapHandlerProvider(hp, funcInfo)

View File

@@ -18,6 +18,13 @@ import (
"github.com/go-chi/chi/v5"
)
// PreMiddlewareProvider is a special middleware provider which will be executed
// before other middlewares on the same "routing" level (AfterRouting/Group/Methods/Any, but not BeforeRouting).
// A route can do something (e.g.: set middleware options) at the place where it is declared,
// and the code will be executed before other middlewares which are added before the declaration.
// Use cases: mark a route with some meta info, set some options for middlewares, etc.
type PreMiddlewareProvider func(next http.Handler) http.Handler
// Bind binding an obj to a handler's context data
func Bind[T any](_ T) http.HandlerFunc {
return func(resp http.ResponseWriter, req *http.Request) {
@@ -41,7 +48,10 @@ func GetForm(dataStore reqctx.RequestDataStore) any {
// Router defines a route based on chi's router
type Router struct {
chiRouter *chi.Mux
chiRouter *chi.Mux
afterRouting []any
curGroupPrefix string
curMiddlewares []any
}
@@ -52,8 +62,9 @@ func NewRouter() *Router {
return &Router{chiRouter: r}
}
// Use supports two middlewares
func (r *Router) Use(middlewares ...any) {
// BeforeRouting adds middlewares which will be executed before the request path gets routed
// It should only be used for framework-level global middlewares when it needs to change request method & path.
func (r *Router) BeforeRouting(middlewares ...any) {
for _, m := range middlewares {
if !isNilOrFuncNil(m) {
r.chiRouter.Use(toHandlerProvider(m))
@@ -61,7 +72,13 @@ func (r *Router) Use(middlewares ...any) {
}
}
// Group mounts a sub-Router along a `pattern` string.
// AfterRouting adds middlewares which will be executed after the request path gets routed
// It can see the routed path and resolved path parameters
func (r *Router) AfterRouting(middlewares ...any) {
r.afterRouting = append(r.afterRouting, middlewares...)
}
// Group mounts a sub-router along a "pattern" string.
func (r *Router) Group(pattern string, fn func(), middlewares ...any) {
previousGroupPrefix := r.curGroupPrefix
previousMiddlewares := r.curMiddlewares
@@ -93,36 +110,54 @@ func isNilOrFuncNil(v any) bool {
return r.Kind() == reflect.Func && r.IsNil()
}
func wrapMiddlewareAndHandler(curMiddlewares, h []any) ([]func(http.Handler) http.Handler, http.HandlerFunc) {
handlerProviders := make([]func(http.Handler) http.Handler, 0, len(curMiddlewares)+len(h)+1)
for _, m := range curMiddlewares {
if !isNilOrFuncNil(m) {
handlerProviders = append(handlerProviders, toHandlerProvider(m))
func wrapMiddlewareAppendPre(all []middlewareProvider, middlewares []any) []middlewareProvider {
for _, m := range middlewares {
if h, ok := m.(PreMiddlewareProvider); ok && h != nil {
all = append(all, toHandlerProvider(middlewareProvider(h)))
}
}
return all
}
func wrapMiddlewareAppendNormal(all []middlewareProvider, middlewares []any) []middlewareProvider {
for _, m := range middlewares {
if _, ok := m.(PreMiddlewareProvider); !ok && !isNilOrFuncNil(m) {
all = append(all, toHandlerProvider(m))
}
}
return all
}
func wrapMiddlewareAndHandler(useMiddlewares, curMiddlewares, h []any) (_ []middlewareProvider, _ http.HandlerFunc, hasPreMiddlewares bool) {
if len(h) == 0 {
panic("no endpoint handler provided")
}
for i, m := range h {
if !isNilOrFuncNil(m) {
handlerProviders = append(handlerProviders, toHandlerProvider(m))
} else if i == len(h)-1 {
panic("endpoint handler can't be nil")
}
if isNilOrFuncNil(h[len(h)-1]) {
panic("endpoint handler can't be nil")
}
handlerProviders := make([]middlewareProvider, 0, len(useMiddlewares)+len(curMiddlewares)+len(h)+1)
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, useMiddlewares)
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, curMiddlewares)
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, h)
hasPreMiddlewares = len(handlerProviders) > 0
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, useMiddlewares)
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, curMiddlewares)
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, h)
middlewares := handlerProviders[:len(handlerProviders)-1]
handlerFunc := handlerProviders[len(handlerProviders)-1](nil).ServeHTTP
mockPoint := RouterMockPoint(MockAfterMiddlewares)
if mockPoint != nil {
middlewares = append(middlewares, mockPoint)
}
return middlewares, handlerFunc
return middlewares, handlerFunc, hasPreMiddlewares
}
// Methods adds the same handlers for multiple http "methods" (separated by ",").
// If any method is invalid, the lower level router will panic.
func (r *Router) Methods(methods, pattern string, h ...any) {
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
fullPattern := r.getPattern(pattern)
if strings.Contains(methods, ",") {
methods := strings.SplitSeq(methods, ",")
@@ -134,15 +169,19 @@ func (r *Router) Methods(methods, pattern string, h ...any) {
}
}
// Mount attaches another Router along ./pattern/*
// Mount attaches another Router along "/pattern/*"
func (r *Router) Mount(pattern string, subRouter *Router) {
subRouter.Use(r.curMiddlewares...)
r.chiRouter.Mount(r.getPattern(pattern), subRouter.chiRouter)
handlerProviders := make([]middlewareProvider, 0, len(r.afterRouting)+len(r.curMiddlewares))
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.afterRouting)
handlerProviders = wrapMiddlewareAppendPre(handlerProviders, r.curMiddlewares)
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.afterRouting)
handlerProviders = wrapMiddlewareAppendNormal(handlerProviders, r.curMiddlewares)
r.chiRouter.With(handlerProviders...).Mount(r.getPattern(pattern), subRouter.chiRouter)
}
// Any delegate requests for all methods
func (r *Router) Any(pattern string, h ...any) {
middlewares, handlerFunc := wrapMiddlewareAndHandler(r.curMiddlewares, h)
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, h)
r.chiRouter.With(middlewares...).HandleFunc(r.getPattern(pattern), handlerFunc)
}
@@ -178,12 +217,16 @@ func (r *Router) Patch(pattern string, h ...any) {
// ServeHTTP implements http.Handler
func (r *Router) ServeHTTP(w http.ResponseWriter, req *http.Request) {
// TODO: need to move it to the top-level common middleware, otherwise each "Mount" will cause it to be executed multiple times, which is inefficient.
r.normalizeRequestPath(w, req, r.chiRouter)
}
// NotFound defines a handler to respond whenever a route could not be found.
func (r *Router) NotFound(h http.HandlerFunc) {
r.chiRouter.NotFound(h)
middlewares, handlerFunc, _ := wrapMiddlewareAndHandler(r.afterRouting, r.curMiddlewares, []any{h})
r.chiRouter.NotFound(func(w http.ResponseWriter, r *http.Request) {
executeMiddlewaresHandler(w, r, middlewares, handlerFunc)
})
}
func (r *Router) normalizeRequestPath(resp http.ResponseWriter, req *http.Request, next http.Handler) {

View File

@@ -27,11 +27,7 @@ func (g *RouterPathGroup) ServeHTTP(resp http.ResponseWriter, req *http.Request)
for _, m := range g.matchers {
if m.matchPath(chiCtx, path) {
chiCtx.RoutePatterns = append(chiCtx.RoutePatterns, m.pattern)
handler := m.handlerFunc
for i := len(m.middlewares) - 1; i >= 0; i-- {
handler = m.middlewares[i](handler).ServeHTTP
}
handler(resp, req)
executeMiddlewaresHandler(resp, req, m.middlewares, m.handlerFunc)
return
}
}
@@ -67,7 +63,7 @@ type routerPathMatcher struct {
pattern string
re *regexp.Regexp
params []routerPathParam
middlewares []func(http.Handler) http.Handler
middlewares []middlewareProvider
handlerFunc http.HandlerFunc
}
@@ -111,7 +107,10 @@ func isValidMethod(name string) bool {
}
func newRouterPathMatcher(methods string, patternRegexp *RouterPathGroupPattern, h ...any) *routerPathMatcher {
middlewares, handlerFunc := wrapMiddlewareAndHandler(patternRegexp.middlewares, h)
middlewares, handlerFunc, hasPreMiddlewares := wrapMiddlewareAndHandler(nil, patternRegexp.middlewares, h)
if hasPreMiddlewares {
panic("pre-middlewares are not supported in router path matcher")
}
p := &routerPathMatcher{methods: make(container.Set[string]), middlewares: middlewares, handlerFunc: handlerFunc}
for method := range strings.SplitSeq(methods, ",") {
method = strings.TrimSpace(method)

View File

@@ -30,6 +30,71 @@ func chiURLParamsToMap(chiCtx *chi.Context) map[string]string {
return util.Iif(len(m) == 0, nil, m)
}
type testResult struct {
method string
pathParams map[string]string
handlerMarks []string
chiRoutePattern *string
}
type testRecorder struct {
res testResult
}
func (r *testRecorder) reset() {
r.res = testResult{}
}
func (r *testRecorder) handle(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
mark := util.OptionalArg(optMark, "")
return func(resp http.ResponseWriter, req *http.Request) {
chiCtx := chi.RouteContext(req.Context())
r.res.method = req.Method
r.res.pathParams = chiURLParamsToMap(chiCtx)
r.res.chiRoutePattern = new(chiCtx.RoutePattern())
if mark != "" {
r.res.handlerMarks = append(r.res.handlerMarks, mark)
}
}
}
func (r *testRecorder) provider(optMark ...string) func(next http.Handler) http.Handler {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(resp http.ResponseWriter, req *http.Request) {
r.handle(optMark...)(resp, req)
next.ServeHTTP(resp, req)
})
}
}
func (r *testRecorder) stop(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
mark := util.OptionalArg(optMark, "")
return func(resp http.ResponseWriter, req *http.Request) {
if stop := req.FormValue("stop"); stop != "" && (mark == "" || mark == stop) {
r.handle(stop)(resp, req)
resp.WriteHeader(http.StatusOK)
} else if mark != "" {
r.res.handlerMarks = append(r.res.handlerMarks, mark)
}
}
}
func (r *testRecorder) test(t *testing.T, rt *Router, methodPath string, expected testResult) {
r.reset()
methodPathFields := strings.Fields(methodPath)
req, err := http.NewRequest(methodPathFields[0], methodPathFields[1], nil)
assert.NoError(t, err)
buff := &bytes.Buffer{}
httpRecorder := httptest.NewRecorder()
httpRecorder.Body = buff
rt.ServeHTTP(httpRecorder, req)
if expected.chiRoutePattern == nil {
r.res.chiRoutePattern = nil
}
assert.Equal(t, expected, r.res)
}
func TestPathProcessor(t *testing.T) {
testProcess := func(pattern, uri string, expectedPathParams map[string]string) {
chiCtx := chi.NewRouteContext()
@@ -51,42 +116,10 @@ func TestPathProcessor(t *testing.T) {
}
func TestRouter(t *testing.T) {
buff := &bytes.Buffer{}
recorder := httptest.NewRecorder()
recorder.Body = buff
type resultStruct struct {
method string
pathParams map[string]string
handlerMarks []string
chiRoutePattern *string
}
var res resultStruct
h := func(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
mark := util.OptionalArg(optMark, "")
return func(resp http.ResponseWriter, req *http.Request) {
chiCtx := chi.RouteContext(req.Context())
res.method = req.Method
res.pathParams = chiURLParamsToMap(chiCtx)
res.chiRoutePattern = new(chiCtx.RoutePattern())
if mark != "" {
res.handlerMarks = append(res.handlerMarks, mark)
}
}
}
stopMark := func(optMark ...string) func(resp http.ResponseWriter, req *http.Request) {
mark := util.OptionalArg(optMark, "")
return func(resp http.ResponseWriter, req *http.Request) {
if stop := req.FormValue("stop"); stop != "" && (mark == "" || mark == stop) {
h(stop)(resp, req)
resp.WriteHeader(http.StatusOK)
} else if mark != "" {
res.handlerMarks = append(res.handlerMarks, mark)
}
}
}
type resultStruct = testResult
resRecorder := &testRecorder{}
h := resRecorder.handle
stopMark := resRecorder.stop
r := NewRouter()
r.NotFound(h("not-found:/"))
@@ -123,15 +156,7 @@ func TestRouter(t *testing.T) {
testRoute := func(t *testing.T, methodPath string, expected resultStruct) {
t.Run(methodPath, func(t *testing.T) {
res = resultStruct{}
methodPathFields := strings.Fields(methodPath)
req, err := http.NewRequest(methodPathFields[0], methodPathFields[1], nil)
assert.NoError(t, err)
r.ServeHTTP(recorder, req)
if expected.chiRoutePattern == nil {
res.chiRoutePattern = nil
}
assert.Equal(t, expected, res)
resRecorder.test(t, r, methodPath, expected)
})
}
@@ -273,3 +298,39 @@ func TestRouteNormalizePath(t *testing.T) {
testPath("/v2/", paths{EscapedPath: "/v2", RawPath: "/v2", Path: "/v2"})
testPath("/v2/%2f", paths{EscapedPath: "/v2/%2f", RawPath: "/v2/%2f", Path: "/v2//"})
}
func TestPreMiddlewareProvider(t *testing.T) {
resRecorder := &testRecorder{}
h := resRecorder.handle
p := resRecorder.provider
root := NewRouter()
root.BeforeRouting(h("before-root"))
root.AfterRouting(h("root"))
root.Get("/a/1", h("mid"), PreMiddlewareProvider(p("pre-root")), h("end1"))
sub := NewRouter()
sub.BeforeRouting(h("before-sub"))
sub.AfterRouting(h("sub"))
sub.Get("/2", h("mid"), PreMiddlewareProvider(p("pre-sub")), h("end2"))
sub.NotFound(h("not-found"))
root.Mount("/a", sub)
resRecorder.test(t, root, "GET /a/1", testResult{
method: "GET",
handlerMarks: []string{"before-root", "pre-root", "root", "mid", "end1"},
})
resRecorder.test(t, root, "GET /a/2", testResult{
method: "GET",
handlerMarks: []string{"before-root", "root", "before-sub", "pre-sub", "sub", "mid", "end2"},
})
resRecorder.test(t, root, "GET /no-such", testResult{
method: "GET",
handlerMarks: []string{"before-root"},
})
resRecorder.test(t, root, "GET /a/no-such", testResult{
method: "GET",
handlerMarks: []string{"before-root", "root", "before-sub", "sub", "not-found"},
})
}

View File

@@ -44,7 +44,7 @@ func MarkLongPolling(resp http.ResponseWriter, req *http.Request) {
}
// UpdatePanicError updates a context's error info, a panic may be recovered by other middlewares, but we still need to know that.
func UpdatePanicError(ctx context.Context, err any) {
func UpdatePanicError(ctx context.Context, err error) {
record, ok := ctx.Value(contextKey).(*requestRecord)
if !ok {
return

View File

@@ -5,11 +5,13 @@ package routing
import (
"context"
"fmt"
"net/http"
"sync"
"time"
"code.gitea.io/gitea/modules/graceful"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/process"
)
@@ -99,7 +101,7 @@ func (manager *requestRecordsManager) handler(next http.Handler) http.Handler {
localPanicErr := recover()
if localPanicErr != nil {
record.lock.Lock()
record.panicError = localPanicErr
record.panicError = fmt.Errorf("%v\n%s", localPanicErr, log.Stack(2))
record.lock.Unlock()
}

View File

@@ -24,5 +24,5 @@ type requestRecord struct {
// mutable fields
isLongPolling bool
funcInfo *FuncInfo
panicError any
panicError error
}

View File

@@ -103,7 +103,7 @@ func init() {
func ArtifactsRoutes(prefix string) *web.Router {
m := web.NewRouter()
m.Use(ArtifactContexter())
m.AfterRouting(ArtifactContexter())
r := artifactRoutes{
prefix: prefix,

View File

@@ -94,7 +94,7 @@ func verifyAuth(r *web.Router, authMethods []auth.Method) {
}
authGroup := auth.NewGroup(authMethods...)
r.Use(func(ctx *context.Context) {
r.AfterRouting(func(ctx *context.Context) {
var err error
ctx.Doer, err = authGroup.Verify(ctx.Req, ctx.Resp, ctx, ctx.Session)
if err != nil {
@@ -111,7 +111,7 @@ func verifyAuth(r *web.Router, authMethods []auth.Method) {
func CommonRoutes() *web.Router {
r := web.NewRouter()
r.Use(context.PackageContexter())
r.AfterRouting(context.PackageContexter())
verifyAuth(r, []auth.Method{
&auth.OAuth2{},
@@ -533,7 +533,7 @@ func CommonRoutes() *web.Router {
func ContainerRoutes() *web.Router {
r := web.NewRouter()
r.Use(context.PackageContexter())
r.AfterRouting(context.PackageContexter())
verifyAuth(r, []auth.Method{
&auth.Basic{},

View File

@@ -77,7 +77,6 @@ import (
repo_model "code.gitea.io/gitea/models/repo"
"code.gitea.io/gitea/models/unit"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/graceful"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
api "code.gitea.io/gitea/modules/structs"
@@ -756,13 +755,9 @@ func buildAuthGroup() *auth.Group {
&auth.Basic{}, // FIXME: this should be removed once we don't allow basic auth in API
)
if setting.Service.EnableReverseProxyAuthAPI {
group.Add(&auth.ReverseProxy{})
group.Add(&auth.ReverseProxy{}) // TODO: does it still make sense to support reverse proxy auth in API?
}
if setting.IsWindows && auth_model.IsSSPIEnabled(graceful.GetManager().ShutdownContext()) {
group.Add(&auth.SSPI{}) // it MUST be the last, see the comment of SSPI
}
// others: API doesn't support SSPI auth because the caller should use token
return group
}
@@ -872,9 +867,9 @@ func checkDeprecatedAuthMethods(ctx *context.APIContext) {
func Routes() *web.Router {
m := web.NewRouter()
m.Use(securityHeaders())
m.BeforeRouting(securityHeaders())
if setting.CORSConfig.Enabled {
m.Use(cors.Handler(cors.Options{
m.BeforeRouting(cors.Handler(cors.Options{
AllowedOrigins: setting.CORSConfig.AllowDomain,
AllowedMethods: setting.CORSConfig.Methods,
AllowCredentials: setting.CORSConfig.AllowCredentials,
@@ -882,14 +877,14 @@ func Routes() *web.Router {
MaxAge: int(setting.CORSConfig.MaxAge.Seconds()),
}))
}
m.Use(context.APIContexter())
m.Use(checkDeprecatedAuthMethods)
m.AfterRouting(context.APIContexter())
m.AfterRouting(checkDeprecatedAuthMethods)
// Get user from session if logged in.
m.Use(apiAuth(buildAuthGroup()))
m.AfterRouting(apiAuth(buildAuthGroup()))
m.Use(verifyAuthWithOptions(&common.VerifyOptions{
m.AfterRouting(verifyAuthWithOptions(&common.VerifyOptions{
SignInRequired: setting.Service.RequireSignInViewStrict,
}))

View File

@@ -53,18 +53,18 @@ func renderServerErrorPage(w http.ResponseWriter, req *http.Request, respCode in
_, _ = io.Copy(w, outBuf)
}
// RenderPanicErrorPage renders a 500 page, and it never panics
func RenderPanicErrorPage(w http.ResponseWriter, req *http.Request, err any) {
combinedErr := fmt.Sprintf("%v\n%s", err, log.Stack(2))
log.Error("PANIC: %s", combinedErr)
// renderPanicErrorPage renders a 500 page with the recovered panic value, it handles the stack trace, and it never panics
func renderPanicErrorPage(w http.ResponseWriter, req *http.Request, recovered any) {
combinedErr := fmt.Errorf("%v\n%s", recovered, log.Stack(2))
log.Error("PANIC: %v", combinedErr)
defer func() {
if err := recover(); err != nil {
log.Error("Panic occurs again when rendering error page: %v. Stack:\n%s", err, log.Stack(2))
log.Error("Panic occurs again when rendering error page: %v. Stack:\n%s", combinedErr, log.Stack(2))
}
}()
routing.UpdatePanicError(req.Context(), err)
routing.UpdatePanicError(req.Context(), combinedErr)
plainMsg := "Internal Server Error"
ctxData := middleware.GetContextData(req.Context())
@@ -72,7 +72,7 @@ func RenderPanicErrorPage(w http.ResponseWriter, req *http.Request, err any) {
// Otherwise, the 500-page may cause new panics, eg: cache.GetContextWithData, it makes the developer&users couldn't find the original panic.
user, _ := ctxData[middleware.ContextDataKeySignedUser].(*user_model.User)
if !setting.IsProd || (user != nil && user.IsAdmin) {
plainMsg = "PANIC: " + combinedErr
plainMsg = "PANIC: " + combinedErr.Error()
ctxData["ErrorMsg"] = plainMsg
}
renderServerErrorPage(w, req, http.StatusInternalServerError, tplStatus500, ctxData, plainMsg)

View File

@@ -22,7 +22,7 @@ func TestRenderPanicErrorPage(t *testing.T) {
w := httptest.NewRecorder()
req := &http.Request{URL: &url.URL{}, Header: http.Header{"Accept": []string{"text/html"}}}
req = req.WithContext(reqctx.NewRequestContextForTest(t.Context()))
RenderPanicErrorPage(w, req, errors.New("fake panic error (for test only)"))
renderPanicErrorPage(w, req, errors.New("fake panic error (for test only)"))
respContent := w.Body.String()
assert.Contains(t, respContent, `class="page-content status-page-500"`)
assert.Contains(t, respContent, `</html>`)

View File

@@ -5,13 +5,13 @@ package common
import (
"fmt"
"log"
"net/http"
"strings"
"code.gitea.io/gitea/modules/cache"
"code.gitea.io/gitea/modules/gtprof"
"code.gitea.io/gitea/modules/httplib"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/reqctx"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/web/routing"
@@ -63,8 +63,8 @@ func RequestContextHandler() func(h http.Handler) http.Handler {
}()
defer func() {
if err := recover(); err != nil {
RenderPanicErrorPage(respWriter, req, err) // it should never panic
if recovered := recover(); recovered != nil {
renderPanicErrorPage(respWriter, req, recovered) // it should never panic, and it handles the stack trace internally
}
}()
@@ -130,7 +130,7 @@ func MustInitSessioner() func(next http.Handler) http.Handler {
Domain: setting.SessionConfig.Domain,
})
if err != nil {
log.Fatalf("common.Sessioner failed: %v", err)
log.Fatal("common.Sessioner failed: %v", err)
}
return middleware
}

View File

@@ -180,8 +180,9 @@ func InitWebInstalled(ctx context.Context) {
// NormalRoutes represents non install routes
func NormalRoutes() *web.Router {
r := web.NewRouter()
r.Use(common.ProtocolMiddlewares()...)
r.Use(common.MaintenanceModeHandler())
r.BeforeRouting(common.ProtocolMiddlewares()...)
r.AfterRouting(common.MaintenanceModeHandler())
r.Mount("/", web_routers.Routes())
r.Mount("/api/v1", apiv1.Routes())

View File

@@ -20,11 +20,12 @@ import (
// Routes registers the installation routes
func Routes() *web.Router {
base := web.NewRouter()
base.Use(common.ProtocolMiddlewares()...)
base.BeforeRouting(common.ProtocolMiddlewares()...)
base.Methods("GET, HEAD", "/assets/*", public.FileHandlerFunc())
r := web.NewRouter()
r.Use(common.MustInitSessioner(), installContexter())
r.AfterRouting(common.MustInitSessioner(), installContexter())
r.Get("/", Install) // it must be on the root, because the "install.js" use the window.location to replace the "localhost" AppURL
r.Post("/", web.Bind(forms.InstallForm{}), SubmitInstall)

View File

@@ -54,11 +54,11 @@ func bind[T any](_ T) any {
// These APIs will be invoked by internal commands for example `gitea serv` and etc.
func Routes() *web.Router {
r := web.NewRouter()
r.Use(context.PrivateContexter())
r.Use(authInternal)
r.AfterRouting(context.PrivateContexter())
r.AfterRouting(authInternal)
// Log the real ip address of the request from SSH is really helpful for diagnosing sometimes.
// Since internal API will be sent only from Gitea sub commands and it's under control (checked by InternalToken), we can trust the headers.
r.Use(chi_middleware.RealIP)
r.AfterRouting(chi_middleware.RealIP)
r.Get("/dummy", misc.DummyOK)
r.Post("/ssh/authorized_keys", AuthorizedPublicKeyByContent)

View File

@@ -6,12 +6,9 @@ package web
import (
"code.gitea.io/gitea/modules/web"
"code.gitea.io/gitea/routers/web/repo"
"code.gitea.io/gitea/services/context"
)
func addOwnerRepoGitHTTPRouters(m *web.Router) {
// Some users want to use "web-based git client" to access Gitea's repositories,
// so the CORS handler and OPTIONS method are used.
func addOwnerRepoGitHTTPRouters(m *web.Router, middlewares ...any) {
m.Group("/{username}/{reponame}", func() {
m.Methods("POST,OPTIONS", "/git-upload-pack", repo.ServiceUploadPack)
m.Methods("POST,OPTIONS", "/git-receive-pack", repo.ServiceReceivePack)
@@ -25,5 +22,5 @@ func addOwnerRepoGitHTTPRouters(m *web.Router) {
m.Methods("GET,OPTIONS", "/objects/{head:[0-9a-f]{2}}/{hash:[0-9a-f]{38,62}}", repo.GetLooseObject)
m.Methods("GET,OPTIONS", "/objects/pack/pack-{file:[0-9a-f]{40,64}}.pack", repo.GetPackFile)
m.Methods("GET,OPTIONS", "/objects/pack/pack-{file:[0-9a-f]{40,64}}.idx", repo.GetIdxFile)
}, repo.HTTPGitEnabledHandler, repo.CorsHandler(), optSignInFromAnyOrigin, context.UserAssignmentWeb())
}, middlewares...)
}

View File

@@ -15,6 +15,7 @@ import (
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/metrics"
"code.gitea.io/gitea/modules/public"
"code.gitea.io/gitea/modules/reqctx"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/storage"
"code.gitea.io/gitea/modules/structs"
@@ -90,32 +91,64 @@ func optionsCorsHandler() func(next http.Handler) http.Handler {
}
}
// The OAuth2 plugin is expected to be executed first, as it must ignore the user id stored
// in the session (if there is a user id stored in session other plugins might return the user
// object for that id).
//
// The Session plugin is expected to be executed second, in order to skip authentication
// for users that have already signed in.
func buildAuthGroup() *auth_service.Group {
group := auth_service.NewGroup()
group.Add(&auth_service.OAuth2{}) // FIXME: this should be removed and only applied in download and oauth related routers
group.Add(&auth_service.Basic{}) // FIXME: this should be removed and only applied in download and git/lfs routers
if setting.Service.EnableReverseProxyAuth {
group.Add(&auth_service.ReverseProxy{}) // reverse-proxy should before Session, otherwise the header will be ignored if user has login
}
group.Add(&auth_service.Session{})
if setting.IsWindows && auth_model.IsSSPIEnabled(graceful.GetManager().ShutdownContext()) {
group.Add(&auth_service.SSPI{}) // it MUST be the last, see the comment of SSPI
}
return group
type AuthMiddleware struct {
AllowOAuth2 web.PreMiddlewareProvider
AllowBasic web.PreMiddlewareProvider
MiddlewareHandler func(*context.Context)
}
func webAuth(authMethod auth_service.Method) func(*context.Context) {
return func(ctx *context.Context) {
ar, err := common.AuthShared(ctx.Base, ctx.Session, authMethod)
func newWebAuthMiddleware() *AuthMiddleware {
type keyAllowOAuth2 struct{}
type keyAllowBasic struct{}
webAuth := &AuthMiddleware{}
middlewareSetContextValue := func(key, val any) web.PreMiddlewareProvider {
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
dataStore := reqctx.GetRequestDataStore(r.Context())
dataStore.SetContextValue(key, val)
next.ServeHTTP(w, r)
})
}
}
webAuth.AllowBasic = middlewareSetContextValue(keyAllowBasic{}, true)
webAuth.AllowOAuth2 = middlewareSetContextValue(keyAllowOAuth2{}, true)
enableSSPI := setting.IsWindows && auth_model.IsSSPIEnabled(graceful.GetManager().ShutdownContext())
webAuth.MiddlewareHandler = func(ctx *context.Context) {
allowBasic := ctx.GetContextValue(keyAllowBasic{}) == true
allowOAuth2 := ctx.GetContextValue(keyAllowOAuth2{}) == true
group := auth_service.NewGroup()
// Most auth methods should ignore the user id stored in the session.
// If the auth succeeds, it must use the user id from the auth method to make sure the new login succeeds.
if allowOAuth2 {
group.Add(&auth_service.OAuth2{})
}
if allowBasic {
group.Add(&auth_service.Basic{})
}
// Sessionless means the route's auth can be done without web ui, then it doesn't need to create a session
// For example: accessing git via http, access rss feeds, downloading attachments, etc
isSessionless := allowOAuth2 || allowBasic
if setting.Service.EnableReverseProxyAuth {
// reverse-proxy should before Session, otherwise the header will be ignored if user has login
group.Add(&auth_service.ReverseProxy{CreateSession: !isSessionless})
}
// The Session plugin will skip authentication for users that have already signed in.
group.Add(&auth_service.Session{})
if enableSSPI {
// it MUST be the last, see the comment of SSPI
group.Add(&auth_service.SSPI{CreateSession: !isSessionless})
}
ar, err := common.AuthShared(ctx.Base, ctx.Session, group)
if err != nil {
log.Error("Failed to verify user: %v", err)
ctx.HTTPError(http.StatusUnauthorized, "Failed to authenticate user")
@@ -129,6 +162,7 @@ func webAuth(authMethod auth_service.Method) func(*context.Context) {
_ = ctx.Session.Delete("uid")
}
}
return webAuth
}
// verifyAuthWithOptions checks authentication according to options
@@ -223,6 +257,9 @@ const RouterMockPointBeforeWebRoutes = "before-web-routes"
func Routes() *web.Router {
routes := web.NewRouter()
// GetHead allows a HEAD request redirect to GET if HEAD method is not defined for that route
routes.BeforeRouting(chi_middleware.GetHead)
routes.Head("/", misc.DummyOK) // for health check - doesn't need to be passed through gzip handler
routes.Methods("GET, HEAD, OPTIONS", "/assets/*", optionsCorsHandler(), public.FileHandlerFunc())
routes.Methods("GET, HEAD", "/avatars/*", avatarStorageHandler(setting.Avatar.Storage, "avatars", storage.Avatars))
@@ -260,10 +297,8 @@ func Routes() *web.Router {
mid = append(mid, common.MustInitSessioner(), context.Contexter())
// Get user from session if logged in.
mid = append(mid, webAuth(buildAuthGroup()))
// GetHead allows a HEAD request redirect to GET if HEAD method is not defined for that route
mid = append(mid, chi_middleware.GetHead)
webAuth := newWebAuthMiddleware()
mid = append(mid, webAuth.MiddlewareHandler)
if setting.API.EnableSwagger {
// Note: The route is here but no in API routes because it renders a web page
@@ -272,10 +307,12 @@ func Routes() *web.Router {
mid = append(mid, goGet)
mid = append(mid, common.PageGlobalData)
mid = append(mid, common.BlockExpensive(), common.QoS(), web.RouterMockPoint(RouterMockPointBeforeWebRoutes))
webRoutes := web.NewRouter()
webRoutes.Use(mid...)
webRoutes.Group("", func() { registerWebRoutes(webRoutes) }, common.BlockExpensive(), common.QoS(), web.RouterMockPoint(RouterMockPointBeforeWebRoutes))
webRoutes.AfterRouting(mid...)
registerWebRoutes(webRoutes, webAuth)
routes.Mount("", webRoutes)
return routes
}
@@ -288,7 +325,7 @@ func Routes() *web.Router {
var optSignInFromAnyOrigin = verifyAuthWithOptions(&common.VerifyOptions{DisableCrossOriginProtection: true})
// registerWebRoutes register routes
func registerWebRoutes(m *web.Router) {
func registerWebRoutes(m *web.Router, webAuth *AuthMiddleware) {
// required to be signed in or signed out
reqSignIn := verifyAuthWithOptions(&common.VerifyOptions{SignInRequired: true})
reqSignOut := verifyAuthWithOptions(&common.VerifyOptions{SignOutRequired: true})
@@ -565,7 +602,7 @@ func registerWebRoutes(m *web.Router) {
m.Methods("POST, OPTIONS", "/access_token", web.Bind(forms.AccessTokenForm{}), auth.AccessTokenOAuth)
m.Methods("GET, OPTIONS", "/keys", auth.OIDCKeys)
m.Methods("POST, OPTIONS", "/introspect", web.Bind(forms.IntrospectTokenForm{}), auth.IntrospectOAuth)
}, optionsCorsHandler(), optSignInFromAnyOrigin)
}, optionsCorsHandler(), webAuth.AllowOAuth2, optSignInFromAnyOrigin)
}, oauth2Enabled)
m.Group("/user/settings", func() {
@@ -816,8 +853,9 @@ func registerWebRoutes(m *web.Router) {
// ***** END: Admin *****
m.Group("", func() {
m.Get("/{username}", user.UsernameSubRoute)
m.Methods("GET, OPTIONS", "/attachments/{uuid}", optionsCorsHandler(), repo.GetAttachment)
// it handles "username.rss" in the handler, so allow basic auth as other rss/atom routes
m.Get("/{username}", webAuth.AllowBasic, user.UsernameSubRoute)
m.Methods("GET, OPTIONS", "/attachments/{uuid}", optionsCorsHandler(), webAuth.AllowBasic, webAuth.AllowOAuth2, repo.GetAttachment)
}, optSignIn)
m.Post("/{username}", reqSignIn, context.UserAssignmentWeb(), user.ActionUserFollow)
@@ -1188,7 +1226,7 @@ func registerWebRoutes(m *web.Router) {
// end "/{username}/{reponame}/settings"
// user/org home, including rss feeds like "/{username}/{reponame}.rss"
m.Get("/{username}/{reponame}", optSignIn, context.RepoAssignment, context.RepoRefByType(git.RefTypeBranch), repo.SetEditorconfigIfExists, repo.Home)
m.Get("/{username}/{reponame}", optSignIn, webAuth.AllowBasic, context.RepoAssignment, context.RepoRefByType(git.RefTypeBranch), repo.SetEditorconfigIfExists, repo.Home)
m.Post("/{username}/{reponame}/markup", optSignIn, context.RepoAssignment, reqUnitsWithMarkdown, web.Bind(structs.MarkupOption{}), misc.Markup)
@@ -1389,8 +1427,8 @@ func registerWebRoutes(m *web.Router) {
m.Group("/{username}/{reponame}", func() { // repo tags
m.Group("/tags", func() {
m.Get("", context.RepoRefByDefaultBranch() /* for the "commits" tab */, repo.TagsList)
m.Get(".rss", feedEnabled, repo.TagsListFeedRSS)
m.Get(".atom", feedEnabled, repo.TagsListFeedAtom)
m.Get(".rss", webAuth.AllowBasic, feedEnabled, repo.TagsListFeedRSS)
m.Get(".atom", webAuth.AllowBasic, feedEnabled, repo.TagsListFeedAtom)
m.Get("/list", repo.GetTagList)
}, ctxDataSet("EnableFeed", setting.Other.EnableFeed))
m.Post("/tags/delete", reqSignIn, reqRepoCodeWriter, context.RepoMustNotBeArchived(), repo.DeleteTag)
@@ -1400,13 +1438,13 @@ func registerWebRoutes(m *web.Router) {
m.Group("/{username}/{reponame}", func() { // repo releases
m.Group("/releases", func() {
m.Get("", repo.Releases)
m.Get(".rss", feedEnabled, repo.ReleasesFeedRSS)
m.Get(".atom", feedEnabled, repo.ReleasesFeedAtom)
m.Get(".rss", webAuth.AllowBasic, feedEnabled, repo.ReleasesFeedRSS)
m.Get(".atom", webAuth.AllowBasic, feedEnabled, repo.ReleasesFeedAtom)
m.Get("/tag/*", repo.SingleRelease)
m.Get("/latest", repo.LatestRelease)
}, ctxDataSet("EnableFeed", setting.Other.EnableFeed))
m.Get("/releases/attachments/{uuid}", repo.GetAttachment)
m.Get("/releases/download/{vTag}/{fileName}", repo.RedirectDownload)
m.Get("/releases/attachments/{uuid}", webAuth.AllowBasic, webAuth.AllowOAuth2, repo.GetAttachment)
m.Get("/releases/download/{vTag}/{fileName}", webAuth.AllowBasic, webAuth.AllowOAuth2, repo.RedirectDownload)
m.Group("/releases", func() {
m.Get("/new", repo.NewRelease)
m.Post("/new", web.Bind(forms.NewReleaseForm{}), repo.NewReleasePost)
@@ -1423,7 +1461,7 @@ func registerWebRoutes(m *web.Router) {
// end "/{username}/{reponame}": repo releases
m.Group("/{username}/{reponame}", func() { // to maintain compatibility with old attachments
m.Get("/attachments/{uuid}", repo.GetAttachment)
m.Get("/attachments/{uuid}", webAuth.AllowBasic, webAuth.AllowOAuth2, repo.GetAttachment)
}, optSignIn, context.RepoAssignment)
// end "/{username}/{reponame}": compatibility with old attachments
@@ -1492,7 +1530,7 @@ func registerWebRoutes(m *web.Router) {
m.Post("/rerun", reqRepoActionsWriter, actions.Rerun)
})
m.Group("/workflows/{workflow_name}", func() {
m.Get("/badge.svg", actions.GetWorkflowBadge)
m.Get("/badge.svg", webAuth.AllowBasic, webAuth.AllowOAuth2, actions.GetWorkflowBadge)
})
}, optSignIn, context.RepoAssignment, repo.MustBeNotEmpty, reqRepoActionsReader, actions.MustEnableActions)
// end "/{username}/{reponame}/actions"
@@ -1578,7 +1616,7 @@ func registerWebRoutes(m *web.Router) {
m.Group("/archive", func() {
m.Get("/*", repo.Download)
m.Post("/*", repo.InitiateDownload)
}, repo.MustBeNotEmpty, dlSourceEnabled)
}, webAuth.AllowBasic, webAuth.AllowOAuth2, repo.MustBeNotEmpty, dlSourceEnabled)
m.Group("/branches", func() {
m.Get("/list", repo.GetBranchesList)
@@ -1591,7 +1629,7 @@ func registerWebRoutes(m *web.Router) {
m.Get("/tag/*", context.RepoRefByType(git.RefTypeTag), repo.SingleDownloadOrLFS)
m.Get("/commit/*", context.RepoRefByType(git.RefTypeCommit), repo.SingleDownloadOrLFS)
m.Get("/*", context.RepoRefByType(""), repo.SingleDownloadOrLFS) // "/*" route is deprecated, and kept for backward compatibility
}, repo.MustBeNotEmpty)
}, webAuth.AllowBasic, webAuth.AllowOAuth2, repo.MustBeNotEmpty)
m.Group("/raw", func() {
m.Get("/blob/{sha}", repo.DownloadByID)
@@ -1599,7 +1637,7 @@ func registerWebRoutes(m *web.Router) {
m.Get("/tag/*", context.RepoRefByType(git.RefTypeTag), repo.SingleDownload)
m.Get("/commit/*", context.RepoRefByType(git.RefTypeCommit), repo.SingleDownload)
m.Get("/*", context.RepoRefByType(""), repo.SingleDownload) // "/*" route is deprecated, and kept for backward compatibility
}, repo.MustBeNotEmpty)
}, webAuth.AllowBasic, webAuth.AllowOAuth2, repo.MustBeNotEmpty)
m.Group("/render", func() {
m.Get("/branch/*", context.RepoRefByType(git.RefTypeBranch), repo.RenderFile)
@@ -1632,8 +1670,8 @@ func registerWebRoutes(m *web.Router) {
m.Get("/cherry-pick/{sha:([a-f0-9]{7,64})$}", repo.SetEditorconfigIfExists, context.RepoRefByDefaultBranch(), repo.CherryPick)
}, repo.MustBeNotEmpty)
m.Get("/rss/branch/*", context.RepoRefByType(git.RefTypeBranch), feedEnabled, feed.RenderBranchFeedRSS)
m.Get("/atom/branch/*", context.RepoRefByType(git.RefTypeBranch), feedEnabled, feed.RenderBranchFeedAtom)
m.Get("/rss/branch/*", context.RepoRefByType(git.RefTypeBranch), webAuth.AllowBasic, feedEnabled, feed.RenderBranchFeedRSS)
m.Get("/atom/branch/*", context.RepoRefByType(git.RefTypeBranch), webAuth.AllowBasic, feedEnabled, feed.RenderBranchFeedAtom)
m.Group("/src", func() {
m.Get("", func(ctx *context.Context) { ctx.Redirect(ctx.Repo.RepoLink) }) // there is no "{owner}/{repo}/src" page, so redirect to "{owner}/{repo}" to avoid 404
@@ -1660,9 +1698,14 @@ func registerWebRoutes(m *web.Router) {
m.Post("/action/{action:accept_transfer|reject_transfer}", reqSignIn, repo.ActionTransfer)
}, optSignIn, context.RepoAssignment)
common.AddOwnerRepoGitLFSRoutes(m, lfsServerEnabled, repo.CorsHandler(), optSignInFromAnyOrigin) // "/{username}/{reponame}/{lfs-paths}": git-lfs support, see also addOwnerRepoGitHTTPRouters
// git lfs uses its own jwt key, and it handles the token & auth by itself, it conflicts with the general "OAuth2" auth method
// pattern: "/{username}/{reponame}/{lfs-paths}": git-lfs support, see also addOwnerRepoGitHTTPRouters
common.AddOwnerRepoGitLFSRoutes(m, lfsServerEnabled, webAuth.AllowBasic, repo.CorsHandler(), optSignInFromAnyOrigin)
addOwnerRepoGitHTTPRouters(m) // "/{username}/{reponame}/{git-paths}": git http support
// Some users want to use "web-based git client" to access Gitea's repositories,
// so the CORS handler and OPTIONS method are used.
// pattern: "/{username}/{reponame}/{git-paths}": git http support
addOwnerRepoGitHTTPRouters(m, repo.HTTPGitEnabledHandler, webAuth.AllowBasic, webAuth.AllowOAuth2, repo.CorsHandler(), optSignInFromAnyOrigin, context.UserAssignmentWeb())
m.Group("/notifications", func() {
m.Get("", user.Notifications)

View File

@@ -8,38 +8,16 @@ import (
"errors"
"fmt"
"net/http"
"regexp"
"strings"
"sync"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/auth/webauthn"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/optional"
"code.gitea.io/gitea/modules/session"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/web/middleware"
user_service "code.gitea.io/gitea/services/user"
)
type globalVarsStruct struct {
gitRawOrAttachPathRe *regexp.Regexp
lfsPathRe *regexp.Regexp
archivePathRe *regexp.Regexp
feedPathRe *regexp.Regexp
feedRefPathRe *regexp.Regexp
}
var globalVars = sync.OnceValue(func() *globalVarsStruct {
return &globalVarsStruct{
gitRawOrAttachPathRe: regexp.MustCompile(`^/[-.\w]+/[-.\w]+/(?:(?:git-(?:(?:upload)|(?:receive))-pack$)|(?:info/refs$)|(?:HEAD$)|(?:objects/)|(?:raw/)|(?:releases/download/)|(?:attachments/))`),
lfsPathRe: regexp.MustCompile(`^/[-.\w]+/[-.\w]+/info/lfs/`),
archivePathRe: regexp.MustCompile(`^/[-.\w]+/[-.\w]+/archive/`),
feedPathRe: regexp.MustCompile(`^/[-.\w]+(/[-.\w]+)?\.(rss|atom)$`), // "/owner.rss" or "/owner/repo.atom"
feedRefPathRe: regexp.MustCompile(`^/[-.\w]+/[-.\w]+/(rss|atom)/`), // "/owner/repo/rss/branch/..."
}
})
type ErrUserAuthMessage string
func (e ErrUserAuthMessage) Error() string {
@@ -60,66 +38,6 @@ func Init() {
webauthn.Init()
}
type authPathDetector struct {
req *http.Request
vars *globalVarsStruct
}
func newAuthPathDetector(req *http.Request) *authPathDetector {
return &authPathDetector{req: req, vars: globalVars()}
}
// isAPIPath returns true if the specified URL is an API path
func (a *authPathDetector) isAPIPath() bool {
return strings.HasPrefix(a.req.URL.Path, "/api/")
}
// isAttachmentDownload check if request is a file download (GET) with URL to an attachment
func (a *authPathDetector) isAttachmentDownload() bool {
return strings.HasPrefix(a.req.URL.Path, "/attachments/") && a.req.Method == http.MethodGet
}
func (a *authPathDetector) isFeedRequest(req *http.Request) bool {
if !setting.Other.EnableFeed {
return false
}
if req.Method != http.MethodGet {
return false
}
return a.vars.feedPathRe.MatchString(req.URL.Path) || a.vars.feedRefPathRe.MatchString(req.URL.Path)
}
// isContainerPath checks if the request targets the container endpoint
func (a *authPathDetector) isContainerPath() bool {
return strings.HasPrefix(a.req.URL.Path, "/v2/")
}
func (a *authPathDetector) isGitRawOrAttachPath() bool {
return a.vars.gitRawOrAttachPathRe.MatchString(a.req.URL.Path)
}
func (a *authPathDetector) isGitRawOrAttachOrLFSPath() bool {
if a.isGitRawOrAttachPath() {
return true
}
if setting.LFS.StartServer {
return a.vars.lfsPathRe.MatchString(a.req.URL.Path)
}
return false
}
func (a *authPathDetector) isArchivePath() bool {
return a.vars.archivePathRe.MatchString(a.req.URL.Path)
}
func (a *authPathDetector) isAuthenticatedTokenRequest() bool {
switch a.req.URL.Path {
case "/login/oauth/userinfo", "/login/oauth/introspect":
return true
}
return false
}
// handleSignIn clears existing session variables and stores new ones for the specified user object
func handleSignIn(resp http.ResponseWriter, req *http.Request, sess SessionStore, user *user_model.User) {
// We need to regenerate the session...

View File

@@ -1,155 +0,0 @@
// Copyright 2014 The Gogs Authors. All rights reserved.
// Copyright 2019 The Gitea Authors. All rights reserved.
// SPDX-License-Identifier: MIT
package auth
import (
"net/http"
"testing"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/test"
"github.com/stretchr/testify/assert"
)
func Test_isGitRawOrLFSPath(t *testing.T) {
tests := []struct {
path string
want bool
}{
{
"/owner/repo/git-upload-pack",
true,
},
{
"/owner/repo/git-receive-pack",
true,
},
{
"/owner/repo/info/refs",
true,
},
{
"/owner/repo/HEAD",
true,
},
{
"/owner/repo/objects/info/alternates",
true,
},
{
"/owner/repo/objects/info/http-alternates",
true,
},
{
"/owner/repo/objects/info/packs",
true,
},
{
"/owner/repo/objects/info/blahahsdhsdkla",
true,
},
{
"/owner/repo/objects/01/23456789abcdef0123456789abcdef01234567",
true,
},
{
"/owner/repo/objects/pack/pack-123456789012345678921234567893124567894.pack",
true,
},
{
"/owner/repo/objects/pack/pack-0123456789abcdef0123456789abcdef0123456.idx",
true,
},
{
"/owner/repo/raw/branch/foo/fanaso",
true,
},
{
"/owner/repo/stars",
false,
},
{
"/notowner",
false,
},
{
"/owner/repo",
false,
},
{
"/owner/repo/commit/123456789012345678921234567893124567894",
false,
},
{
"/owner/repo/releases/download/tag/repo.tar.gz",
true,
},
{
"/owner/repo/attachments/6d92a9ee-5d8b-4993-97c9-6181bdaa8955",
true,
},
}
defer test.MockVariableValue(&setting.LFS.StartServer)()
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, "http://localhost"+tt.path, nil)
setting.LFS.StartServer = false
assert.Equal(t, tt.want, newAuthPathDetector(req).isGitRawOrAttachOrLFSPath())
setting.LFS.StartServer = true
assert.Equal(t, tt.want, newAuthPathDetector(req).isGitRawOrAttachOrLFSPath())
})
}
lfsTests := []string{
"/owner/repo/info/lfs/",
"/owner/repo/info/lfs/objects/batch",
"/owner/repo/info/lfs/objects/oid/filename",
"/owner/repo/info/lfs/objects/oid",
"/owner/repo/info/lfs/objects",
"/owner/repo/info/lfs/verify",
"/owner/repo/info/lfs/locks",
"/owner/repo/info/lfs/locks/verify",
"/owner/repo/info/lfs/locks/123/unlock",
}
for _, tt := range lfsTests {
t.Run(tt, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodPost, tt, nil)
setting.LFS.StartServer = false
got := newAuthPathDetector(req).isGitRawOrAttachOrLFSPath()
assert.Equalf(t, setting.LFS.StartServer, got, "isGitOrLFSPath(%q) = %v, want %v, %v", tt, got, setting.LFS.StartServer, globalVars().gitRawOrAttachPathRe.MatchString(tt))
setting.LFS.StartServer = true
got = newAuthPathDetector(req).isGitRawOrAttachOrLFSPath()
assert.Equalf(t, setting.LFS.StartServer, got, "isGitOrLFSPath(%q) = %v, want %v", tt, got, setting.LFS.StartServer)
})
}
}
func Test_isFeedRequest(t *testing.T) {
tests := []struct {
want bool
path string
}{
{true, "/user.rss"},
{true, "/user/repo.atom"},
{false, "/user/repo"},
{false, "/use/repo/file.rss"},
{true, "/org/repo/rss/branch/xxx"},
{true, "/org/repo/atom/tag/xxx"},
{false, "/org/repo/branch/main/rss/any"},
{false, "/org/atom/any"},
}
for _, tt := range tests {
t.Run(tt.path, func(t *testing.T) {
req, _ := http.NewRequest(http.MethodGet, "http://localhost"+tt.path, nil)
assert.Equal(t, tt.want, newAuthPathDetector(req).isFeedRequest(req))
})
}
}

View File

@@ -41,13 +41,6 @@ func (b *Basic) Name() string {
}
func (b *Basic) parseAuthBasic(req *http.Request) (ret struct{ authToken, uname, passwd string }) {
// Basic authentication should only fire on API, Feed, Download, Archives or on Git or LFSPaths
// Not all feed (rss/atom) clients feature the ability to add cookies or headers, so we need to allow basic auth for feeds
detector := newAuthPathDetector(req)
if !detector.isAPIPath() && !detector.isFeedRequest(req) && !detector.isContainerPath() && !detector.isAttachmentDownload() && !detector.isArchivePath() && !detector.isGitRawOrAttachOrLFSPath() {
return ret
}
authHeader := req.Header.Get("Authorization")
if authHeader == "" {
return ret

View File

@@ -152,13 +152,6 @@ func (o *OAuth2) userFromToken(ctx context.Context, tokenSHA string, store DataS
// If verification is successful returns an existing user object.
// Returns nil if verification fails.
func (o *OAuth2) Verify(req *http.Request, w http.ResponseWriter, store DataStore, sess SessionStore) (*user_model.User, error) {
// These paths are not API paths, but we still want to check for tokens because they maybe in the API returned URLs
detector := newAuthPathDetector(req)
if !detector.isAPIPath() && !detector.isAttachmentDownload() && !detector.isAuthenticatedTokenRequest() &&
!detector.isGitRawOrAttachPath() && !detector.isArchivePath() {
return nil, nil //nolint:nilnil // the auth method is not applicable
}
token, ok := parseToken(req)
if !ok {
return nil, nil //nolint:nilnil // the auth method is not applicable

View File

@@ -29,7 +29,9 @@ const ReverseProxyMethodName = "reverse_proxy"
// On successful authentication the proxy is expected to populate the username in the
// "setting.ReverseProxyAuthUser" header. Optionally it can also populate the email of the
// user in the "setting.ReverseProxyAuthEmail" header.
type ReverseProxy struct{}
type ReverseProxy struct {
CreateSession bool
}
// getUserName extracts the username from the "setting.ReverseProxyAuthUser" header
func (r *ReverseProxy) getUserName(req *http.Request) string {
@@ -115,9 +117,7 @@ func (r *ReverseProxy) Verify(req *http.Request, w http.ResponseWriter, store Da
}
}
// Make sure requests to API paths, attachment downloads, git and LFS do not create a new session
detector := newAuthPathDetector(req)
if !detector.isAPIPath() && !detector.isAttachmentDownload() && !detector.isGitRawOrAttachOrLFSPath() {
if r.CreateSession {
if sess != nil && (sess.Get("uid") == nil || sess.Get("uid").(int64) != user.ID) {
handleSignIn(w, req, sess, user)
}

View File

@@ -46,7 +46,9 @@ var (
// The SSPI plugin is expected to be executed last, as it returns 401 status code if negotiation
// fails (or if negotiation should continue), which would prevent other authentication methods
// to execute at all.
type SSPI struct{}
type SSPI struct {
CreateSession bool
}
// Name represents the name of auth method
func (s *SSPI) Name() string {
@@ -118,9 +120,7 @@ func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore,
}
}
// Make sure requests to API paths and PWA resources do not create a new session
detector := newAuthPathDetector(req)
if !detector.isAPIPath() && !detector.isAttachmentDownload() {
if s.CreateSession {
handleSignIn(w, req, sess, user)
}
@@ -147,18 +147,9 @@ func (s *SSPI) getConfig(ctx context.Context) (*sspi.Source, error) {
}
func (s *SSPI) shouldAuthenticate(req *http.Request) (shouldAuth bool) {
shouldAuth = false
path := strings.TrimSuffix(req.URL.Path, "/")
if path == "/user/login" {
if req.FormValue("user_name") != "" && req.FormValue("password") != "" {
shouldAuth = false
} else if req.FormValue("auth_with_sspi") == "1" {
shouldAuth = true
}
} else {
detector := newAuthPathDetector(req)
shouldAuth = detector.isAPIPath() || detector.isAttachmentDownload()
}
// SSPI is only applicable for login requests with "auth_with_sspi" form value set to "1"
// See the template code with "auth_with_sspi"
shouldAuth = req.URL.Path == "/user/login" && req.FormValue("auth_with_sspi") == "1"
return shouldAuth
}

View File

@@ -22,6 +22,7 @@ import (
"code.gitea.io/gitea/models/unittest"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/gitrepo"
"code.gitea.io/gitea/modules/httplib"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/structs"
"code.gitea.io/gitea/services/migrations"
@@ -126,7 +127,7 @@ func Test_MigrateFromGiteaToGitea(t *testing.T) {
session := loginUser(t, owner.Name)
token := getTokenForLoggedInUser(t, session, auth_model.AccessTokenScopeAll)
resp, err := http.Get("https://gitea.com/gitea")
resp, err := httplib.NewRequest("https://gitea.com/gitea/test_repo.git", "GET").SetReadWriteTimeout(5 * time.Second).Response()
if err != nil || resp.StatusCode != http.StatusOK {
if resp != nil {
resp.Body.Close()