本文源码版本基于 go1.21.13

标准库中的 Context 是一个接口,其具体实现有很多种,主要用于跨多个 Goroutine 设置截止时间、同步信号、传递上下文请求值等。

type Context interface {
  
  
    Deadline() (deadline time.Time, ok bool)
  
  
  
    Done() <-chan struct{}
  
  
  
  
  
    Err() error
  
  
  
  
  
    
    Value(key any) any
}

Done 的 demo 用法

func Stream(ctx context.Context, out chan<- Value) error {
    for {
        v, err := DoSomething(ctx)
        if err != nil {
            return err
        }
        select {
        case <-ctx.Done():
            return ctx.Err()
        case out <- v:
        }
    }
}

Value 的 demo 用法

package user
import "context"
type User struct {...}
type key int
var userKey key
func NewContext(ctx context.Context, u *User) context.Context {
    return context.WithValue(ctx, userKey, u)
}
func FromContext(ctx context.Context) (*User, bool) {
    u, ok := ctx.Value(userKey).(*User)
    return u, ok
}

# Context 内部类型

# emptyCtx

emptyCtx 是一个没有 cancel 过的,没有 deadline 的,没有值的空 ctx。

emptyCtx 是 backgroundCtx 和 todoCtx 的共同基础。

type emptyCtx struct{}
func (emptyCtx) Deadline() (deadline time.Time, ok bool) {
    return
}
func (emptyCtx) Done() <-chan struct{} {
    return nil
}
func (emptyCtx) Err() error {
    return nil
}
func (emptyCtx) Value(key any) any {
    return nil
}

# backgroundCtx

context.Background() 返回一个 backgroundCtx ,本质上是 emptyCtx,主要用于 主函数、初始化、test case 作为请求的顶级 Contex 传入

type backgroundCtx struct{ emptyCtx }
func (backgroundCtx) String() string {
    return "context.Background"
}
func Background() Context {
    return backgroundCtx{}
}

# todoCtx

context.TODO() 返回一个 todoCtx ,本质上是 emptyCtx,主要用于不清楚要用哪个 context 或者其他函数还没有开始定义接受 ctx 的参数,

type todoCtx struct{ emptyCtx }
func (todoCtx) String() string {
    return "context.TODO"
}
func TODO() Context {
    return todoCtx{}
}

# cancelCtx

可以被 canceled 的 ctx,当它 canceled 时,也会取消它的 children

type cancelCtx struct {
    Context
    mu       sync.Mutex            
    done     atomic.Value          
    children map[canceler]struct{} 
    err      error                 
    cause    error                 
}

具体的取消逻辑见下文

# stopCtx

stopCtx 被用作 cancelCtx 的 parent context 当一个 AfterFunc 已在 parent context 中注册时.
它包含用于取消注册 AfterFunc 的停止函数。

type stopCtx struct {
    Context
    stop func() bool
}

# timerCtx

timerCtx 带有一个计时器和一个截止时间。

通过停止计时器然后委托给 cancelCtx.cancel 来实现特定时间取消。

嵌入了 cancelCtx 来实现 Done 和 Err 操作。

type timerCtx struct {
    cancelCtx
    timer *time.Timer 
    deadline time.Time
}

# valueCtx

valueCtx 携带了两个无限制变量 key, val any 。嵌入了 Context 来实现其他调用。

type valueCtx struct {
    Context
    key, val any
}

# cancel

type cancelCtx struct {
    Context
    mu       sync.Mutex            
  done     atomic.Value          
    children map[canceler]struct{} 
    err      error                 
    cause    error                 
}

# 创建 cancelCtx

type CancelFunc func()
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
    c := withCancel(parent)
    return c, func() { c.cancel(true, Canceled, nil) }
}
type CancelCauseFunc func(cause error)
func WithCancelCause(parent Context) (ctx Context, cancel CancelCauseFunc) {
    c := withCancel(parent)
    return c, func(cause error) { c.cancel(true, Canceled, cause) }
}
func withCancel(parent Context) *cancelCtx {
    if parent == nil {
        panic("cannot create context from nil parent")
    }
    c := &cancelCtx{}
    c.propagateCancel(parent, c)
    return c
}
func (c *cancelCtx) propagateCancel(parent Context, child canceler) {
    c.Context = parent
    done := parent.Done()
    if done == nil {
        return 
    }
  
    
    
    select {
    case <-done:
        
        child.cancel(false, parent.Err(), Cause(parent))
        return
    default:
    }
    
  
    if p, ok := parentCancelCtx(parent); ok {
        
        p.mu.Lock()
        if p.err != nil {
            
            child.cancel(false, p.err, p.cause)
        } else {
            if p.children == nil {
                p.children = make(map[canceler]struct{})
            }
            p.children[child] = struct{}{}
        }
        p.mu.Unlock()
        return
    }
    if a, ok := parent.(afterFuncer); ok {
        
        c.mu.Lock()
        stop := a.AfterFunc(func() {
            child.cancel(false, parent.Err(), Cause(parent))
        })
        c.Context = stopCtx{
            Context: parent,
            stop:    stop,
        }
        c.mu.Unlock()
        return
    }
  
    goroutines.Add(1)
    go func() {
        select {
        case <-parent.Done():
            child.cancel(false, parent.Err(), Cause(parent))
        case <-child.Done():
        }
    }()
}
func parentCancelCtx(parent Context) (*cancelCtx, bool) {
    done := parent.Done()
  
  
    if done == closedchan || done == nil {
        return nil, false
    }
  
    
    p, ok := parent.Value(&cancelCtxKey).(*cancelCtx)
    if !ok {
        return nil, false
    }
  
  
    
    
    
    
    
    
    
    
    
    
    
    
    pdone, _ := p.done.Load().(chan struct{})
    if pdone != done {
        return nil, false
    }
    return p, true
}

# 取消 cancelCtx

cancelCtx 内部跨多个 Goroutine 实现信号传递其实靠的就是一个 done channel;如果要取消这个 Context,那么就需要让所有 <-c.Done() 停止阻塞,这时候最简单的办法就是把这个 channel 直接 close 掉,或者干脆换成一个已经被 close 的 channel

type canceler interface {
    cancel(removeFromParent bool, err, cause error)
    Done() <-chan struct{}
}
func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) {
    if err == nil {
        panic("context: internal error: missing cancel error")
    }
    if cause == nil {
        cause = err
    }
  
    c.mu.Lock()
  
    if c.err != nil {
        c.mu.Unlock()
        return 
    }
    c.err = err
    c.cause = cause
    d, _ := c.done.Load().(chan struct{})
    if d == nil {
    
        c.done.Store(closedchan)
    } else {
    
        close(d)
    }
    for child := range c.children {
        
        child.cancel(false, err, cause)
    }
    c.children = nil
    c.mu.Unlock()
  
    if removeFromParent {
        removeChild(c.Context, c)
    }
}
func removeChild(parent Context, child canceler) {
    if s, ok := parent.(stopCtx); ok {
        s.stop()
        return
    }
    p, ok := parentCancelCtx(parent)
    if !ok {
        return
    }
    p.mu.Lock()
    if p.children != nil {
        delete(p.children, child)
    }
    p.mu.Unlock()
}

# timerCtx

type timerCtx struct {
    cancelCtx
    timer *time.Timer 
    deadline time.Time
}

# 创建 timerCtx

timerCtx 的创建主要通过 context.WithDeadline 方法,

同时 context.WithTimeoutWithTimeoutCause 实际上也是调用的 context.WithDeadlineCause :

func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) {
    return WithDeadline(parent, time.Now().Add(timeout))
}
func WithTimeoutCause(parent Context, timeout time.Duration, cause error) (Context, CancelFunc) {
    return WithDeadlineCause(parent, time.Now().Add(timeout), cause)
}
func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
    return WithDeadlineCause(parent, d, nil)
}
func WithDeadlineCause(parent Context, d time.Time, cause error) (Context, CancelFunc) {
    if parent == nil {
        panic("cannot create context from nil parent")
    }
    if cur, ok := parent.Deadline(); ok && cur.Before(d) {
        
        return WithCancel(parent)
    }
    c := &timerCtx{
        deadline: d,
    }
    c.cancelCtx.propagateCancel(parent, c)
    dur := time.Until(d)
    if dur <= 0 {
        c.cancel(true, DeadlineExceeded, cause) 
        return c, func() { c.cancel(false, Canceled, nil) }
    }
    c.mu.Lock()
    defer c.mu.Unlock()
    if c.err == nil {
        c.timer = time.AfterFunc(dur, func() {
            c.cancel(true, DeadlineExceeded, cause)
        })
    }
    return c, func() { c.cancel(true, Canceled, nil) }
}

# 取消 timerCtx

调用一下里面的 cancelCtx 的 cancel,然后再把定时器停掉:

func (c *timerCtx) cancel(removeFromParent bool, err, cause error) {
    c.cancelCtx.cancel(false, err, cause)
    if removeFromParent {
        
        removeChild(c.cancelCtx.Context, c)
    }
    c.mu.Lock()
    if c.timer != nil {
        c.timer.Stop()
        c.timer = nil
    }
    c.mu.Unlock()
}

# valueCtx

type valueCtx struct {
    Context
    key, val any
}
func WithValue(parent Context, key, val any) Context {
    if parent == nil {
        panic("cannot create context from nil parent")
    }
    if key == nil {
        panic("nil key")
    }
    if !reflectlite.TypeOf(key).Comparable() {
        panic("key is not comparable")
    }
    return &valueCtx{parent, key, val}
}
func (c *valueCtx) String() string {
    return contextName(c.Context) + ".WithValue(type " +
        reflectlite.TypeOf(c.key).String() +
        ", val " + stringify(c.val) + ")"
}
func (c *valueCtx) Value(key any) any {
    if c.key == key {
        return c.val
    }
  
    return value(c.Context, key)
}
func value(c Context, key any) any {
    for {
        switch ctx := c.(type) {
        case *valueCtx:
            if key == ctx.key {
                return ctx.val
            }
            c = ctx.Context
        case *cancelCtx:
            if key == &cancelCtxKey {
                return c
            }
            c = ctx.Context
        case withoutCancelCtx:
            if key == &cancelCtxKey {
                
                
                return nil
            }
            c = ctx.c
        case *timerCtx:
            if key == &cancelCtxKey {
                return &ctx.cancelCtx
            }
            c = ctx.Context
        case backgroundCtx, todoCtx:
            return nil
        default:
            return c.Value(key)
        }
    }
}