Wire up Cause for most context cancels (#7538)

This commit is contained in:
Francis Lavoie
2026-03-04 19:14:52 -05:00
committed by GitHub
parent fbfb8fc517
commit 6e5e08cf58
6 changed files with 42 additions and 23 deletions

View File

@@ -749,10 +749,14 @@ func stopAdminServer(srv *http.Server) error {
if srv == nil { if srv == nil {
return fmt.Errorf("no admin server") return fmt.Errorf("no admin server")
} }
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) timeout := 10 * time.Second
ctx, cancel := context.WithTimeoutCause(context.Background(), timeout, fmt.Errorf("stopping admin server: %ds timeout", int(timeout.Seconds())))
defer cancel() defer cancel()
err := srv.Shutdown(ctx) err := srv.Shutdown(ctx)
if err != nil { if err != nil {
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
err = cause
}
return fmt.Errorf("shutting down admin server: %v", err) return fmt.Errorf("shutting down admin server: %v", err)
} }
Log().Named("admin").Info("stopped previous server", zap.String("address", srv.Addr)) Log().Named("admin").Info("stopped previous server", zap.String("address", srv.Addr))

View File

@@ -88,7 +88,7 @@ type Config struct {
storage certmagic.Storage storage certmagic.Storage
eventEmitter eventEmitter eventEmitter eventEmitter
cancelFunc context.CancelFunc cancelFunc context.CancelCauseFunc
// fileSystems is a dict of fileSystems that will later be loaded from and added to. // fileSystems is a dict of fileSystems that will later be loaded from and added to.
fileSystems FileSystems fileSystems FileSystems
@@ -433,7 +433,7 @@ func run(newCfg *Config, start bool) (Context, error) {
// partially copied from provisionContext // partially copied from provisionContext
if err != nil { if err != nil {
globalMetrics.configSuccess.Set(0) globalMetrics.configSuccess.Set(0)
ctx.cfg.cancelFunc() ctx.cfg.cancelFunc(fmt.Errorf("configuration start error: %w", err))
if currentCtx.cfg != nil { if currentCtx.cfg != nil {
certmagic.Default.Storage = currentCtx.cfg.storage certmagic.Default.Storage = currentCtx.cfg.storage
@@ -509,7 +509,7 @@ func provisionContext(newCfg *Config, replaceAdminServer bool) (Context, error)
// cleanup occurs when we return if there // cleanup occurs when we return if there
// was an error; if no error, it will get // was an error; if no error, it will get
// cleaned up on next config cycle // cleaned up on next config cycle
ctx, cancel := NewContext(Context{Context: context.Background(), cfg: newCfg}) ctx, cancelCause := NewContextWithCause(Context{Context: context.Background(), cfg: newCfg})
defer func() { defer func() {
if err != nil { if err != nil {
globalMetrics.configSuccess.Set(0) globalMetrics.configSuccess.Set(0)
@@ -518,7 +518,7 @@ func provisionContext(newCfg *Config, replaceAdminServer bool) (Context, error)
// since the associated config won't be used; // since the associated config won't be used;
// this will cause all modules that were newly // this will cause all modules that were newly
// provisioned to clean themselves up // provisioned to clean themselves up
cancel() cancelCause(fmt.Errorf("configuration error: %w", err))
// also undo any other state changes we made // also undo any other state changes we made
if currentCtx.cfg != nil { if currentCtx.cfg != nil {
@@ -526,7 +526,7 @@ func provisionContext(newCfg *Config, replaceAdminServer bool) (Context, error)
} }
} }
}() }()
newCfg.cancelFunc = cancel // clean up later newCfg.cancelFunc = cancelCause // clean up later
// set up logging before anything bad happens // set up logging before anything bad happens
if newCfg.Logging == nil { if newCfg.Logging == nil {
@@ -746,7 +746,7 @@ func unsyncedStop(ctx Context) {
} }
// clean up all modules // clean up all modules
ctx.cfg.cancelFunc() ctx.cfg.cancelFunc(fmt.Errorf("stopping apps"))
} }
// Validate loads, provisions, and validates // Validate loads, provisions, and validates
@@ -754,7 +754,7 @@ func unsyncedStop(ctx Context) {
func Validate(cfg *Config) error { func Validate(cfg *Config) error {
_, err := run(cfg, false) _, err := run(cfg, false)
if err == nil { if err == nil {
cfg.cancelFunc() // call Cleanup on all modules cfg.cancelFunc(fmt.Errorf("validation complete")) // call Cleanup on all modules
} }
return err return err
} }

View File

@@ -63,10 +63,17 @@ type Context struct {
// modules which are loaded will be properly unloaded. // modules which are loaded will be properly unloaded.
// See standard library context package's documentation. // See standard library context package's documentation.
func NewContext(ctx Context) (Context, context.CancelFunc) { func NewContext(ctx Context) (Context, context.CancelFunc) {
newCtx, cancelCause := NewContextWithCause(ctx)
return newCtx, func() { cancelCause(nil) }
}
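// NewContextWithCause is like NewContext, but the returned cancel function is a
// context.CancelCauseFunc: the error passed to it is recorded as the reason for
// cancellation and can later be retrieved with context.Cause.
// EXPERIMENTAL: This API is subject to change.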
func NewContextWithCause(ctx Context) (Context, context.CancelCauseFunc) {
newCtx := Context{moduleInstances: make(map[string][]Module), cfg: ctx.cfg, metricsRegistry: prometheus.NewPedanticRegistry()} newCtx := Context{moduleInstances: make(map[string][]Module), cfg: ctx.cfg, metricsRegistry: prometheus.NewPedanticRegistry()}
c, cancel := context.WithCancel(ctx.Context) c, cancel := context.WithCancelCause(ctx.Context)
wrappedCancel := func() { wrappedCancel := func(cause error) {
cancel() cancel(cause)
for _, f := range ctx.cleanupFuncs { for _, f := range ctx.cleanupFuncs {
f() f()

View File

@@ -512,7 +512,7 @@ func ListenerUsage(network, addr string) int {
// contextAndCancelFunc groups context and its cancelFunc // contextAndCancelFunc groups context and its cancelFunc
type contextAndCancelFunc struct { type contextAndCancelFunc struct {
context.Context context.Context
context.CancelFunc context.CancelCauseFunc
} }
// sharedQUICState manages GetConfigForClient // sharedQUICState manages GetConfigForClient
@@ -542,17 +542,17 @@ func (sqs *sharedQUICState) getConfigForClient(ch *tls.ClientHelloInfo) (*tls.Co
// addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc // addState adds tls.Config and activeRequests to the map if not present and returns the corresponding context and its cancelFunc
// so that when cancelled, the active tls.Config will change // so that when cancelled, the active tls.Config will change
func (sqs *sharedQUICState) addState(tlsConfig *tls.Config) (context.Context, context.CancelFunc) { func (sqs *sharedQUICState) addState(tlsConfig *tls.Config) (context.Context, context.CancelCauseFunc) {
sqs.rmu.Lock() sqs.rmu.Lock()
defer sqs.rmu.Unlock() defer sqs.rmu.Unlock()
if cacc, ok := sqs.tlsConfs[tlsConfig]; ok { if cacc, ok := sqs.tlsConfs[tlsConfig]; ok {
return cacc.Context, cacc.CancelFunc return cacc.Context, cacc.CancelCauseFunc
} }
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancelCause(context.Background())
wrappedCancel := func() { wrappedCancel := func(cause error) {
cancel() cancel(cause)
sqs.rmu.Lock() sqs.rmu.Lock()
defer sqs.rmu.Unlock() defer sqs.rmu.Unlock()
@@ -608,13 +608,13 @@ func fakeClosedErr(l interface{ Addr() net.Addr }) error {
// indicating that it is pretending to be closed so that the // indicating that it is pretending to be closed so that the
// server using it can terminate, while the underlying // server using it can terminate, while the underlying
// socket is actually left open. // socket is actually left open.
var errFakeClosed = fmt.Errorf("listener 'closed' 😉") var errFakeClosed = fmt.Errorf("QUIC listener 'closed' 😉")
type fakeCloseQuicListener struct { type fakeCloseQuicListener struct {
closed int32 // accessed atomically; belongs to this struct only closed int32 // accessed atomically; belongs to this struct only
*sharedQuicListener // embedded, so we also become a quic.EarlyListener *sharedQuicListener // embedded, so we also become a quic.EarlyListener
context context.Context context context.Context
contextCancel context.CancelFunc contextCancel context.CancelCauseFunc
} }
// Currently Accept ignores the passed context, however a situation where // Currently Accept ignores the passed context, however a situation where
@@ -637,7 +637,7 @@ func (fcql *fakeCloseQuicListener) Accept(_ context.Context) (*quic.Conn, error)
func (fcql *fakeCloseQuicListener) Close() error { func (fcql *fakeCloseQuicListener) Close() error {
if atomic.CompareAndSwapInt32(&fcql.closed, 0, 1) { if atomic.CompareAndSwapInt32(&fcql.closed, 0, 1) {
fcql.contextCancel() fcql.contextCancel(errFakeClosed)
} else if atomic.CompareAndSwapInt32(&fcql.closed, 1, 2) { } else if atomic.CompareAndSwapInt32(&fcql.closed, 1, 2) {
_, _ = listenerPool.Delete(fcql.sharedQuicListener.key) _, _ = listenerPool.Delete(fcql.sharedQuicListener.key)
} }

View File

@@ -18,6 +18,7 @@ import (
"cmp" "cmp"
"context" "context"
"crypto/tls" "crypto/tls"
"errors"
"fmt" "fmt"
"maps" "maps"
"net" "net"
@@ -711,9 +712,10 @@ func (app *App) Stop() error {
// enforce grace period if configured // enforce grace period if configured
if app.GracePeriod > 0 { if app.GracePeriod > 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, time.Duration(app.GracePeriod)) timeout := time.Duration(app.GracePeriod)
ctx, cancel = context.WithTimeoutCause(ctx, timeout, fmt.Errorf("server graceful shutdown %ds timeout", int(timeout.Seconds())))
defer cancel() defer cancel()
app.logger.Info("servers shutting down; grace period initiated", zap.Duration("duration", time.Duration(app.GracePeriod))) app.logger.Info("servers shutting down; grace period initiated", zap.Duration("duration", timeout))
} else { } else {
app.logger.Info("servers shutting down with eternal grace period") app.logger.Info("servers shutting down with eternal grace period")
} }
@@ -739,6 +741,9 @@ func (app *App) Stop() error {
} }
if err := server.server.Shutdown(ctx); err != nil { if err := server.server.Shutdown(ctx); err != nil {
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
err = cause
}
app.logger.Error("server shutdown", app.logger.Error("server shutdown",
zap.Error(err), zap.Error(err),
zap.Strings("addresses", server.Listen)) zap.Strings("addresses", server.Listen))
@@ -762,6 +767,9 @@ func (app *App) Stop() error {
} }
if err := server.h3server.Shutdown(ctx); err != nil { if err := server.h3server.Shutdown(ctx); err != nil {
if cause := context.Cause(ctx); cause != nil && errors.Is(err, context.DeadlineExceeded) {
err = cause
}
app.logger.Error("HTTP/3 server shutdown", app.logger.Error("HTTP/3 server shutdown",
zap.Error(err), zap.Error(err),
zap.Strings("addresses", server.Listen)) zap.Strings("addresses", server.Listen))

View File

@@ -448,7 +448,7 @@ func (h *HTTPTransport) NewTransport(caddyCtx caddy.Context) (*http.Transport, e
// complete the handshake before returning the connection // complete the handshake before returning the connection
if rt.TLSHandshakeTimeout != 0 { if rt.TLSHandshakeTimeout != 0 {
var cancel context.CancelFunc var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, rt.TLSHandshakeTimeout) ctx, cancel = context.WithTimeoutCause(ctx, rt.TLSHandshakeTimeout, fmt.Errorf("HTTP transport TLS handshake %ds timeout", int(rt.TLSHandshakeTimeout.Seconds())))
defer cancel() defer cancel()
} }
err = tlsConn.HandshakeContext(ctx) err = tlsConn.HandshakeContext(ctx)