mirror of
https://github.com/caddyserver/caddy.git
synced 2026-03-17 14:34:03 +00:00
Wire up Cause for most context cancels (#7538)
This commit is contained in:
6
admin.go
6
admin.go
@@ -749,10 +749,14 @@ func stopAdminServer(srv *http.Server) error {
|
|||||||
if srv == nil {
|
if srv == nil {
|
||||||
return fmt.Errorf("no admin server")
|
return fmt.Errorf("no admin server")
|
||||||
}
|
}
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
timeout := 10 * time.Second
|
||||||
|
ctx, cancel := context.WithTimeoutCause(context.Background(), timeout, fmt.Errorf("stopping admin server: %ds timeout", int(timeout.Seconds())))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
err := srv.Shutdown(ctx)
|
err := srv.Shutdown(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
err = cause
|
||||||
|
}
|
||||||
return fmt.Errorf("shutting down admin server: %v", err)
|
return fmt.Errorf("shutting down admin server: %v", err)
|
||||||
}
|
}
|
||||||
Log().Named("admin").Info("stopped previous server", zap.String("address", srv.Addr))
|
Log().Named("admin").Info("stopped previous server", zap.String("address", srv.Addr))
|
||||||
|
|||||||
14
caddy.go
14
caddy.go
@@ -88,7 +88,7 @@ type Config struct {
|
|||||||
storage certmagic.Storage
|
storage certmagic.Storage
|
||||||
eventEmitter eventEmitter
|
eventEmitter eventEmitter
|
||||||
|
|
||||||
cancelFunc context.CancelFunc
|
cancelFunc context.CancelCauseFunc
|
||||||
|
|
||||||
// fileSystems is a dict of fileSystems that will later be loaded from and added to.
|
// fileSystems is a dict of fileSystems that will later be loaded from and added to.
|
||||||
fileSystems FileSystems
|
fileSystems FileSystems
|
||||||
@@ -433,7 +433,7 @@ func run(newCfg *Config, start bool) (Context, error) {
|
|||||||
// partially copied from provisionContext
|
// partially copied from provisionContext
|
||||||
if err != nil {
|
if err != nil {
|
||||||
globalMetrics.configSuccess.Set(0)
|
globalMetrics.configSuccess.Set(0)
|
||||||
ctx.cfg.cancelFunc()
|
ctx.cfg.cancelFunc(fmt.Errorf("configuration start error: %w", err))
|
||||||
|
|
||||||
if currentCtx.cfg != nil {
|
if currentCtx.cfg != nil {
|
||||||
certmagic.Default.Storage = currentCtx.cfg.storage
|
certmagic.Default.Storage = currentCtx.cfg.storage
|
||||||
@@ -509,7 +509,7 @@ func provisionContext(newCfg *Config, replaceAdminServer bool) (Context, error)
|
|||||||
// cleanup occurs when we return if there
|
// cleanup occurs when we return if there
|
||||||
// was an error; if no error, it will get
|
// was an error; if no error, it will get
|
||||||
// cleaned up on next config cycle
|
// cleaned up on next config cycle
|
||||||
ctx, cancel := NewContext(Context{Context: context.Background(), cfg: newCfg})
|
ctx, cancelCause := NewContextWithCause(Context{Context: context.Background(), cfg: newCfg})
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
globalMetrics.configSuccess.Set(0)
|
globalMetrics.configSuccess.Set(0)
|
||||||
@@ -518,7 +518,7 @@ func provisionContext(newCfg *Config, replaceAdminServer bool) (Context, error)
|
|||||||
// since the associated config won't be used;
|
// since the associated config won't be used;
|
||||||
// this will cause all modules that were newly
|
// this will cause all modules that were newly
|
||||||
// provisioned to clean themselves up
|
// provisioned to clean themselves up
|
||||||
cancel()
|
cancelCause(fmt.Errorf("configuration error: %w", err))
|
||||||
|
|
||||||
// also undo any other state changes we made
|
// also undo any other state changes we made
|
||||||
if currentCtx.cfg != nil {
|
if currentCtx.cfg != nil {
|
||||||
@@ -526,7 +526,7 @@ func provisionContext(newCfg *Config, replaceAdminServer bool) (Context, error)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
newCfg.cancelFunc = cancel // clean up later
|
newCfg.cancelFunc = cancelCause // clean up later
|
||||||
|
|
||||||
// set up logging before anything bad happens
|
// set up logging before anything bad happens
|
||||||
if newCfg.Logging == nil {
|
if newCfg.Logging == nil {
|
||||||
@@ -746,7 +746,7 @@ func unsyncedStop(ctx Context) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// clean up all modules
|
// clean up all modules
|
||||||
ctx.cfg.cancelFunc()
|
ctx.cfg.cancelFunc(fmt.Errorf("stopping apps"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Validate loads, provisions, and validates
|
// Validate loads, provisions, and validates
|
||||||
@@ -754,7 +754,7 @@ func unsyncedStop(ctx Context) {
|
|||||||
func Validate(cfg *Config) error {
|
func Validate(cfg *Config) error {
|
||||||
_, err := run(cfg, false)
|
_, err := run(cfg, false)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
cfg.cancelFunc() // call Cleanup on all modules
|
cfg.cancelFunc(fmt.Errorf("validation complete")) // call Cleanup on all modules
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
13
context.go
13
context.go
@@ -63,10 +63,17 @@ type Context struct {
|
|||||||
// modules which are loaded will be properly unloaded.
|
// modules which are loaded will be properly unloaded.
|
||||||
// See standard library context package's documentation.
|
// See standard library context package's documentation.
|
||||||
func NewContext(ctx Context) (Context, context.CancelFunc) {
|
func NewContext(ctx Context) (Context, context.CancelFunc) {
|
||||||
|
newCtx, cancelCause := NewContextWithCause(ctx)
|
||||||
|
return newCtx, func() { cancelCause(nil) }
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewContextWithCause is like NewContext but returns a context.CancelCauseFunc.
|
||||||
|
// EXPERIMENTAL: This API is subject to change.
|
||||||
|
func NewContextWithCause(ctx Context) (Context, context.CancelCauseFunc) {
|
||||||
newCtx := Context{moduleInstances: make(map[string][]Module), cfg: ctx.cfg, metricsRegistry: prometheus.NewPedanticRegistry()}
|
newCtx := Context{moduleInstances: make(map[string][]Module), cfg: ctx.cfg, metricsRegistry: prometheus.NewPedanticRegistry()}
|
||||||
c, cancel := context.WithCancel(ctx.Context)
|
c, cancel := context.WithCancelCause(ctx.Context)
|
||||||
wrappedCancel := func() {
|
wrappedCancel := func(cause error) {
|
||||||
cancel()
|
cancel(cause)
|
||||||
|
|
||||||
for _, f := range ctx.cleanupFuncs {
|
for _, f := range ctx.cleanupFuncs {
|
||||||
f()
|
f()
|
||||||
|
|||||||
18
listeners.go
18
listeners.go
@@ -512,7 +512,7 @@ func ListenerUsage(network, addr string) int {
|
|||||||
// contextAndCancelFunc groups context and its cancelFunc
|
// contextAndCancelFunc groups context and its cancelFunc
|
||||||
type contextAndCancelFunc struct {
|
type contextAndCancelFunc struct {
|
||||||
context.Context
|
context.Context
|
||||||
context.CancelFunc
|
context.CancelCauseFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// sharedQUICState manages GetConfigForClient
|
// sharedQUICState manages GetConfigForClient
|
||||||
@@ -542,17 +542,17 @@ func (sqs *sharedQUICState) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Co
|
|||||||
|
|
||||||
// addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc
|
// addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc
|
||||||
// so that when cancelled, the active tls.Config will change
|
// so that when cancelled, the active tls.Config will change
|
||||||
func (sqs *sharedQUICState) addState(tlsConfig *tls.Config) (context.Context, context.CancelFunc) {
|
func (sqs *sharedQUICState) addState(tlsConfig *tls.Config) (context.Context, context.CancelCauseFunc) {
|
||||||
sqs.rmu.Lock()
|
sqs.rmu.Lock()
|
||||||
defer sqs.rmu.Unlock()
|
defer sqs.rmu.Unlock()
|
||||||
|
|
||||||
if cacc, ok := sqs.tlsConfs[tlsConfig]; ok {
|
if cacc, ok := sqs.tlsConfs[tlsConfig]; ok {
|
||||||
return cacc.Context, cacc.CancelFunc
|
return cacc.Context, cacc.CancelCauseFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancelCause(context.Background())
|
||||||
wrappedCancel := func() {
|
wrappedCancel := func(cause error) {
|
||||||
cancel()
|
cancel(cause)
|
||||||
|
|
||||||
sqs.rmu.Lock()
|
sqs.rmu.Lock()
|
||||||
defer sqs.rmu.Unlock()
|
defer sqs.rmu.Unlock()
|
||||||
@@ -608,13 +608,13 @@ func fakeClosedErr(l interface{ Addr() net.Addr }) error {
|
|||||||
// indicating that it is pretending to be closed so that the
|
// indicating that it is pretending to be closed so that the
|
||||||
// server using it can terminate, while the underlying
|
// server using it can terminate, while the underlying
|
||||||
// socket is actually left open.
|
// socket is actually left open.
|
||||||
var errFakeClosed = fmt.Errorf("listener 'closed' 😉")
|
var errFakeClosed = fmt.Errorf("QUIC listener 'closed' 😉")
|
||||||
|
|
||||||
type fakeCloseQuicListener struct {
|
type fakeCloseQuicListener struct {
|
||||||
closed int32 // accessed atomically; belongs to this struct only
|
closed int32 // accessed atomically; belongs to this struct only
|
||||||
*sharedQuicListener // embedded, so we also become a quic.EarlyListener
|
*sharedQuicListener // embedded, so we also become a quic.EarlyListener
|
||||||
context context.Context
|
context context.Context
|
||||||
contextCancel context.CancelFunc
|
contextCancel context.CancelCauseFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
// Currently Accept ignores the passed context, however a situation where
|
// Currently Accept ignores the passed context, however a situation where
|
||||||
@@ -637,7 +637,7 @@ func (fcql *fakeCloseQuicListener) Accept(_ context.Context) (*quic.Conn, error)
|
|||||||
|
|
||||||
func (fcql *fakeCloseQuicListener) Close() error {
|
func (fcql *fakeCloseQuicListener) Close() error {
|
||||||
if atomic.CompareAndSwapInt32(&fcql.closed, 0, 1) {
|
if atomic.CompareAndSwapInt32(&fcql.closed, 0, 1) {
|
||||||
fcql.contextCancel()
|
fcql.contextCancel(errFakeClosed)
|
||||||
} else if atomic.CompareAndSwapInt32(&fcql.closed, 1, 2) {
|
} else if atomic.CompareAndSwapInt32(&fcql.closed, 1, 2) {
|
||||||
_, _ = listenerPool.Delete(fcql.sharedQuicListener.key)
|
_, _ = listenerPool.Delete(fcql.sharedQuicListener.key)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"cmp"
|
"cmp"
|
||||||
"context"
|
"context"
|
||||||
"crypto/tls"
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"maps"
|
"maps"
|
||||||
"net"
|
"net"
|
||||||
@@ -711,9 +712,10 @@ func (app *App) Stop() error {
|
|||||||
// enforce grace period if configured
|
// enforce grace period if configured
|
||||||
if app.GracePeriod > 0 {
|
if app.GracePeriod > 0 {
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithTimeout(ctx, time.Duration(app.GracePeriod))
|
timeout := time.Duration(app.GracePeriod)
|
||||||
|
ctx, cancel = context.WithTimeoutCause(ctx, timeout, fmt.Errorf("server graceful shutdown %ds timeout", int(timeout.Seconds())))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
app.logger.Info("servers shutting down; grace period initiated", zap.Duration("duration", time.Duration(app.GracePeriod)))
|
app.logger.Info("servers shutting down; grace period initiated", zap.Duration("duration", timeout))
|
||||||
} else {
|
} else {
|
||||||
app.logger.Info("servers shutting down with eternal grace period")
|
app.logger.Info("servers shutting down with eternal grace period")
|
||||||
}
|
}
|
||||||
@@ -739,6 +741,9 @@ func (app *App) Stop() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := server.server.Shutdown(ctx); err != nil {
|
if err := server.server.Shutdown(ctx); err != nil {
|
||||||
|
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
err = cause
|
||||||
|
}
|
||||||
app.logger.Error("server shutdown",
|
app.logger.Error("server shutdown",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.Strings("addresses", server.Listen))
|
zap.Strings("addresses", server.Listen))
|
||||||
@@ -762,6 +767,9 @@ func (app *App) Stop() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := server.h3server.Shutdown(ctx); err != nil {
|
if err := server.h3server.Shutdown(ctx); err != nil {
|
||||||
|
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
err = cause
|
||||||
|
}
|
||||||
app.logger.Error("HTTP/3 server shutdown",
|
app.logger.Error("HTTP/3 server shutdown",
|
||||||
zap.Error(err),
|
zap.Error(err),
|
||||||
zap.Strings("addresses", server.Listen))
|
zap.Strings("addresses", server.Listen))
|
||||||
|
|||||||
@@ -448,7 +448,7 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
|
|||||||
// complete the handshake before returning the connection
|
// complete the handshake before returning the connection
|
||||||
if rt.TLSHandshakeTimeout != 0 {
|
if rt.TLSHandshakeTimeout != 0 {
|
||||||
var cancel context.CancelFunc
|
var cancel context.CancelFunc
|
||||||
ctx, cancel = context.WithTimeout(ctx, rt.TLSHandshakeTimeout)
|
ctx, cancel = context.WithTimeoutCause(ctx, rt.TLSHandshakeTimeout, fmt.Errorf("HTTP transport TLS handshake %ds timeout", int(rt.TLSHandshakeTimeout.Seconds())))
|
||||||
defer cancel()
|
defer cancel()
|
||||||
}
|
}
|
||||||
err = tlsConn.HandshakeContext(ctx)
|
err = tlsConn.HandshakeContext(ctx)
|
||||||
|
|||||||
Reference in New Issue
Block a user