自学内容网 自学内容网

40分钟学 Go 语言高并发:Pipeline模式(一)

Pipeline模式

一、课程概述

| 学习要点 | 重要程度 | 掌握目标 |
|---|---|---|
| 流水线设计 | ★★★★★ | 掌握Pipeline基本结构和设计原则 |
| 扇入扇出 | ★★★★☆ | 理解并实现多输入多输出的Pipeline |
| 错误传播 | ★★★★★ | 掌握Pipeline中的错误处理机制 |
| 吞吐量优化 | ★★★★☆ | 学会优化Pipeline的性能和吞吐量 |

二、Pipeline模式基础

让我们首先实现一个基础的Pipeline框架:

package pipeline

import (
    "context"
    "fmt"
    "sync"
)

// Stage is one step of a Pipeline. It receives the upstream channel and
// returns the channel it writes its own output to. A non-nil error means the
// stage could not be started.
type Stage func(ctx context.Context, in <-chan interface{}) (<-chan interface{}, error)

// Pipeline chains Stages so that each stage consumes the previous one's output.
type Pipeline struct {
    stages []Stage
}

// New returns a Pipeline that runs stages in the given order.
func New(stages ...Stage) *Pipeline {
    return &Pipeline{stages: stages}
}

// Run wires the stages together: in feeds the first stage, and each stage's
// output feeds the next. It returns the last stage's output channel, or in
// itself when the Pipeline has no stages.
//
// If a stage fails to start, Run returns a nil output channel and an error
// channel that holds exactly one error, "stage <i> failed: <cause>". That
// error wraps cause, so errors.Is/As work, and the channel is then closed.
// On success the returned error channel is never written to or closed, so a
// caller can select on it next to the output without it firing.
//
// Stages started before a failing one keep running until ctx is cancelled,
// so callers should cancel ctx after receiving an error.
//
// Every call gets its own error channel, so Run may be called repeatedly or
// concurrently on the same Pipeline. A single shared channel would be closed
// by the first failure and make every later Run panic.
func (p *Pipeline) Run(ctx context.Context, in <-chan interface{}) (<-chan interface{}, <-chan error) {
    errCh := make(chan error, 1) // capacity 1: the single error never blocks Run

    out := in
    for i, stage := range p.stages {
        next, err := stage(ctx, out)
        if err != nil {
            errCh <- fmt.Errorf("stage %d failed: %w", i, err)
            close(errCh)
            return nil, errCh
        }
        out = next
    }

    return out, errCh
}

// Merge fans in: it forwards every value received from any of channels to a
// single returned channel. Values from different inputs are interleaved in no
// particular order.
//
// Each input has its own forwarding goroutine. A forwarder stops when its
// input is closed and drained, or when ctx is cancelled. Once every forwarder
// has stopped, the returned channel is closed. With no inputs it is closed
// almost immediately.
func Merge(ctx context.Context, channels ...<-chan interface{}) <-chan interface{} {
    merged := make(chan interface{})

    var pending sync.WaitGroup
    pending.Add(len(channels))

    forward := func(src <-chan interface{}) {
        defer pending.Done()
        for v := range src {
            select {
            case <-ctx.Done():
                return
            case merged <- v:
            }
        }
    }
    for _, src := range channels {
        go forward(src)
    }

    // Closer: the only goroutine that closes merged, after all senders are done.
    go func() {
        pending.Wait()
        close(merged)
    }()

    return merged
}

// Split fans out: it returns n channels, each fed by its own goroutine. All
// of those goroutines compete to receive from in, so every value of in is
// delivered to exactly one output. Which output gets it depends on
// scheduling, not on round-robin.
//
// Each output is closed when in is closed and drained, or when ctx is
// cancelled. A value already received when cancellation wins the select is
// dropped. A negative n is treated as 0, which means no outputs and in is
// never read.
func Split(ctx context.Context, in <-chan interface{}, n int) []<-chan interface{} {
    if n < 0 {
        n = 0 // make() with a negative length would panic
    }

    outs := make([]<-chan interface{}, n)
    for i := range outs {
        // Keep the bidirectional channel for the sender. Converting it back
        // from <-chan via a type assertion is not possible: type assertions
        // only apply to interface values.
        ch := make(chan interface{})
        outs[i] = ch

        go func(ch chan<- interface{}) {
            defer close(ch)
            for v := range in {
                select {
                case ch <- v:
                case <-ctx.Done():
                    return
                }
            }
        }(ch)
    }

    return outs
}

让我们实现一个具体的示例 - 数字处理Pipeline。注意:下面的main函数直接调用了上文定义的New和Run,因此需要与上面的Pipeline代码放在同一个包中编译(或者改为导入pipeline包并使用pipeline.New):

package main

import (
    "context"
    "fmt"
    "log"
    "time"
)

// generator emits the ints 1..n, in order, on the returned channel. The
// channel is closed after n has been sent or as soon as ctx is cancelled.
// The error result is always nil; it exists so the signature matches how the
// other stages are called.
func generator(ctx context.Context, n int) (<-chan interface{}, error) {
    nums := make(chan interface{})

    go func() {
        defer close(nums)
        for next := 1; next <= n; next++ {
            select {
            case <-ctx.Done():
                return
            case nums <- next:
            }
        }
    }()

    return nums, nil
}

// square is a Stage that emits v*v for every int v received from in. Values
// that are not ints are silently dropped. The output channel is closed when
// in is closed and drained, or when ctx is cancelled. The error result is
// always nil.
func square(ctx context.Context, in <-chan interface{}) (<-chan interface{}, error) {
    squares := make(chan interface{})

    go func() {
        defer close(squares)
        for item := range in {
            v, isInt := item.(int)
            if !isInt {
                continue
            }
            select {
            case <-ctx.Done():
                return
            case squares <- v * v:
            }
        }
    }()

    return squares, nil
}

// filter is a Stage that forwards only the int values divisible by 3. Every
// other value is discarded, including values that are not ints. The output
// channel is closed when in is closed and drained, or when ctx is cancelled.
// The error result is always nil.
func filter(ctx context.Context, in <-chan interface{}) (<-chan interface{}, error) {
    kept := make(chan interface{})

    go func() {
        defer close(kept)
        for item := range in {
            v, isInt := item.(int)
            if !isInt || v%3 != 0 {
                continue
            }
            select {
            case <-ctx.Done():
                return
            case kept <- v:
            }
        }
    }()

    return kept, nil
}

// main runs the demo: the numbers 1..10 go through square and then filter,
// and every surviving value is printed. The whole run is bounded by a
// 5-second timeout.
//
// NOTE(review): New is the constructor from the pipeline snippet above. This
// file must be in the same package as it, or call pipeline.New via an import.
func main() {
    ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
    defer cancel()

    // square and filter already have the Stage signature, so they can be
    // handed to New directly. Order matters: square runs first.
    p := New(square, filter)

    input, err := generator(ctx, 10)
    if err != nil {
        log.Fatalf("Generator failed: %v", err)
    }

    output, errCh := p.Run(ctx, input)

    // Drain results until output closes, a stage reports an error, or the
    // timeout fires.
    for {
        select {
        case <-ctx.Done():
            fmt.Println("Pipeline cancelled")
            return
        case err := <-errCh:
            if err == nil {
                continue
            }
            log.Printf("Pipeline error: %v", err)
            return
        case n, ok := <-output:
            if !ok {
                return
            }
            fmt.Printf("Output: %v\n", n)
        }
    }
}

三、Pipeline流程图

(原文此处为Pipeline流程图,图片在转载时已丢失。)

四、高级Pipeline实现

让我们实现一个更复杂的Pipeline,包含错误处理和性能优化:

package pipeline

import (
    "context"
    "fmt"
    "runtime"
    "sync"
    "time"
)

// Result is the outcome of running one input item through every stage of an
// AdvancedPipeline. On success Value holds the last stage's output and Err is
// nil. On failure Err is set and Value is left nil.
type Result struct {
    Value interface{}
    Err   error
}

// StageFunc transforms a single item. AdvancedPipeline passes the value
// returned by one stage as the input to the next, and a non-nil error fails
// the current attempt, which may then be retried. It takes no context, so a
// running StageFunc cannot be interrupted when its attempt times out.
type StageFunc func(interface{}) (interface{}, error)

// Options configures an AdvancedPipeline.
type Options struct {
    BufferSize  int           // capacity of the Result channel returned by Process
    NumWorkers  int           // number of worker goroutines; NewAdvanced replaces values <= 0 with runtime.NumCPU()
    Timeout     time.Duration // per-attempt time limit for each stage; NOTE(review): NewAdvanced sets no default for it
    RetryCount  int           // extra attempts after a stage's first failed attempt
    RetryDelay  time.Duration // pause between attempts of the same stage
}

// AdvancedPipeline runs every input item through a fixed sequence of
// StageFuncs using a pool of worker goroutines. It supports per-attempt
// timeouts, retries, an optional error handler and aggregate metrics.
// Create it with NewAdvanced.
type AdvancedPipeline struct {
    stages   []StageFunc // applied in order to each item
    options  Options     // tuning knobs read by Process and the retry logic
    metrics  *Metrics    // shared by all workers; fields guarded by Metrics.mu
    errHandler func(error) error // optional; set with SetErrorHandler, nil means none
}

// Metrics holds aggregate processing statistics for an AdvancedPipeline.
// Every field below mu is read and written only while holding mu.
type Metrics struct {
    mu            sync.RWMutex
    processedItems int64         // items handled, successful or not
    errorCount     int64         // items whose Result carried an error
    avgProcessTime time.Duration // summary of per-item processing time; see updateMetrics for how it is computed
}

// NewAdvanced builds an AdvancedPipeline that runs stages, in order, on every
// input item. A non-positive opts.NumWorkers is replaced by runtime.NumCPU().
// Every other option is stored exactly as given. No error handler is set
// initially.
func NewAdvanced(opts Options, stages ...StageFunc) *AdvancedPipeline {
    workers := opts.NumWorkers
    if workers <= 0 {
        workers = runtime.NumCPU()
    }
    opts.NumWorkers = workers

    p := new(AdvancedPipeline)
    p.stages = stages
    p.options = opts
    p.metrics = new(Metrics)
    return p
}

// SetErrorHandler installs handler. When a stage fails, processItem calls it
// with that stage's error, and a non-nil return is folded into the item's
// Result error. Passing nil removes the handler. The assignment is not
// synchronized, so set the handler before calling Process.
func (p *AdvancedPipeline) SetErrorHandler(handler func(error) error) {
    pipeline := p
    pipeline.errHandler = handler
}

// Process starts a pool of workers. The workers read items from input, run
// each item through every stage in order (see processItem), and send one
// Result per item to the returned channel. Several workers run concurrently,
// so results are not guaranteed to arrive in input order. The returned
// channel is closed once every worker has returned.
//
// It returns an error only when the pipeline has no stages.
func (p *AdvancedPipeline) Process(ctx context.Context, input <-chan interface{}) (<-chan Result, error) {
    if len(p.stages) == 0 {
        return nil, fmt.Errorf("no stages defined")
    }

    // NumWorkers and BufferSize can reach this point with invalid values,
    // either directly through Options or later through the With* setters.
    // With zero workers, output would close at once and input would never be
    // read. A negative buffer size makes make() panic.
    workers := p.options.NumWorkers
    if workers <= 0 {
        workers = runtime.NumCPU()
    }
    bufSize := p.options.BufferSize
    if bufSize < 0 {
        bufSize = 0
    }

    output := make(chan Result, bufSize)

    var wg sync.WaitGroup
    wg.Add(workers)
    for i := 0; i < workers; i++ {
        go func(workerID int) {
            defer wg.Done()
            p.worker(ctx, workerID, input, output)
        }(i)
    }

    // Only this goroutine closes output, after every sender has finished.
    go func() {
        wg.Wait()
        close(output)
    }()

    return output, nil
}

// worker receives items from input until input is closed or ctx is
// cancelled. For each item it calls processItem, records metrics, and
// forwards the Result to output. id is the worker's index in the pool and is
// currently unused.
func (p *AdvancedPipeline) worker(ctx context.Context, id int, input <-chan interface{}, output chan<- Result) {
    for {
        // select picks randomly among ready cases. Check first so a cancelled
        // worker stops promptly even when more input is already queued.
        if ctx.Err() != nil {
            return
        }

        var data interface{}
        select {
        case <-ctx.Done():
            // Without this case, a worker blocked on an idle input that is
            // never closed would leak, and output would never be closed.
            return
        case item, ok := <-input:
            if !ok {
                return
            }
            data = item
        }

        startTime := time.Now()
        result := p.processItem(ctx, data)
        p.updateMetrics(startTime, result.Err != nil)

        select {
        case output <- result:
        case <-ctx.Done():
            return
        }
    }
}

// processItem runs data through every stage in order, feeding each stage's
// output to the next. Retries and timeouts are handled by
// executeStageWithRetry. It stops at the first stage that fails and returns
// a Result whose Err reads "stage <i> failed: <cause>" and wraps cause, so
// errors.Is/As still work. If an error handler is set, it receives the
// unwrapped cause, and a non-nil return from it is appended as
// "(handler error: ...)".
func (p *AdvancedPipeline) processItem(ctx context.Context, data interface{}) Result {
    value := data

    for i, stage := range p.stages {
        out, err := p.executeStageWithRetry(ctx, stage, value)
        if err != nil {
            // Always attach the stage index so callers can tell which stage
            // failed, whether or not an error handler is installed.
            wrapped := fmt.Errorf("stage %d failed: %w", i, err)
            if p.errHandler != nil {
                if handlerErr := p.errHandler(err); handlerErr != nil {
                    wrapped = fmt.Errorf("stage %d failed: %w (handler error: %v)", i, err, handlerErr)
                }
            }
            return Result{Err: wrapped}
        }
        value = out
    }

    return Result{Value: value}
}

// executeStageWithRetry runs stage on data for up to 1+RetryCount attempts,
// waiting RetryDelay between attempts. It returns the first successful
// result. If ctx is cancelled, it returns ctx.Err() rather than treating the
// cancellation as a stage failure. Otherwise, after the last failed attempt,
// it returns "all retry attempts failed: <last error>", which wraps the last
// error.
func (p *AdvancedPipeline) executeStageWithRetry(ctx context.Context, stage StageFunc, data interface{}) (interface{}, error) {
    // Clamp so a negative RetryCount still yields one attempt. Without the
    // clamp the loop would be skipped and the error would be "...: <nil>".
    retries := p.options.RetryCount
    if retries < 0 {
        retries = 0
    }

    var lastErr error
    for attempt := 0; attempt <= retries; attempt++ {
        result, err := p.runStageAttempt(ctx, stage, data)
        if err == nil {
            return result, nil
        }
        // A cancelled parent context is not a stage failure. Report it as is
        // instead of masking it as "stage timeout" and continuing to retry.
        if ctxErr := ctx.Err(); ctxErr != nil {
            return nil, ctxErr
        }
        lastErr = err

        if attempt < retries {
            select {
            case <-time.After(p.options.RetryDelay):
            case <-ctx.Done():
                return nil, ctx.Err()
            }
        }
    }

    return nil, fmt.Errorf("all retry attempts failed: %w", lastErr)
}

// runStageAttempt runs stage once and returns "stage timeout" if the attempt
// does not finish within Options.Timeout or before ctx is done. A
// non-positive Timeout disables the per-attempt limit, so only ctx applies.
// Passing such a value to context.WithTimeout would create an already-expired
// context, and almost every attempt would report a timeout.
//
// StageFunc takes no context, so a timed-out stage goroutine cannot be
// stopped. It keeps running until stage returns, and its result is discarded.
func (p *AdvancedPipeline) runStageAttempt(ctx context.Context, stage StageFunc, data interface{}) (interface{}, error) {
    attemptCtx := ctx
    if p.options.Timeout > 0 {
        var cancel context.CancelFunc
        attemptCtx, cancel = context.WithTimeout(ctx, p.options.Timeout)
        defer cancel()
    }

    type outcome struct {
        value interface{}
        err   error
    }
    // Buffered, so the stage goroutine can always deliver and exit even after
    // this function has stopped waiting for it.
    done := make(chan outcome, 1)
    go func() {
        v, err := stage(data)
        done <- outcome{value: v, err: err}
    }()

    select {
    case o := <-done:
        return o.value, o.err
    case <-attemptCtx.Done():
        return nil, fmt.Errorf("stage timeout")
    }
}

// updateMetrics records one processed item. It increments the processed
// counter, increments the error counter when hasError is true, and folds the
// time elapsed since startTime into avgProcessTime. avgProcessTime is the
// arithmetic mean over all items recorded since creation or the last Reset.
func (p *AdvancedPipeline) updateMetrics(startTime time.Time, hasError bool) {
    // Measure before locking, so time spent waiting for a contended mutex is
    // not counted as processing time.
    elapsed := time.Since(startTime)

    m := p.metrics
    m.mu.Lock()
    defer m.mu.Unlock()

    m.processedItems++
    if hasError {
        m.errorCount++
    }

    // Incremental mean: avg_n = avg_(n-1) + (x_n - avg_(n-1)) / n.
    // The (avg+x)/2 shortcut is an exponential moving average that gives the
    // newest sample 50% weight, which is not an average over all items.
    // Integer-nanosecond truncation here is negligible.
    m.avgProcessTime += (elapsed - m.avgProcessTime) / time.Duration(m.processedItems)
}

// GetMetrics returns a consistent snapshot, taken under the read lock, of
// three values in this order: the number of processed items, the number of
// items that ended with an error, and the average processing time.
func (p *AdvancedPipeline) GetMetrics() (int64, int64, time.Duration) {
    m := p.metrics
    m.mu.RLock()
    processed, failed, avg := m.processedItems, m.errorCount, m.avgProcessTime
    m.mu.RUnlock()
    return processed, failed, avg
}

// Reset zeroes all metrics (processed count, error count and average
// processing time) under the write lock.
func (p *AdvancedPipeline) Reset() {
    m := p.metrics
    m.mu.Lock()
    m.processedItems, m.errorCount, m.avgProcessTime = 0, 0, 0
    m.mu.Unlock()
}

// WithBufferSize sets Options.BufferSize and returns p so calls can be
// chained. It modifies p in place without locking; Process reads the option
// when it starts.
func (p *AdvancedPipeline) WithBufferSize(size int) *AdvancedPipeline {
    opts := &p.options
    opts.BufferSize = size
    return p
}

// WithWorkers sets Options.NumWorkers and returns p so calls can be chained.
// The value is stored as given: the replacement of non-positive values that
// NewAdvanced performs does not happen here.
func (p *AdvancedPipeline) WithWorkers(n int) *AdvancedPipeline {
    opts := &p.options
    opts.NumWorkers = n
    return p
}

// WithTimeout sets the per-attempt stage timeout (Options.Timeout) and
// returns p so calls can be chained. It modifies p in place without locking.
func (p *AdvancedPipeline) WithTimeout(timeout time.Duration) *AdvancedPipeline {
    opts := &p.options
    opts.Timeout = timeout
    return p
}

// WithRetry sets the retry policy and returns p so calls can be chained.
// count is the number of extra attempts after a stage's first failure, and
// delay is the pause between attempts.
func (p *AdvancedPipeline) WithRetry(count int, delay time.Duration) *AdvancedPipeline {
    p.options.RetryCount, p.options.RetryDelay = count, delay
    return p
}

以上即为高级Pipeline的完整实现。

原文接下来给出了一个具体的使用示例,但该示例代码在转载时缺失。

五、Pipeline模式最佳实践

让我们总结Pipeline模式的最佳实践和性能优化策略:

5.1 设计原则

  1. 单一职责
  • 每个Stage只负责一个特定的任务
  • 保持Stage功能的独立性
  • 便于测试和维护
  2. 错误处理
  • 及时发现和处理错误
  • 提供错误重试机制
  • 支持优雅降级
  3. 资源管理
  • 合理控制goroutine数量
  • 适当设置channel缓冲区
  • 注意资源释放
  4. 可扩展性
  • 支持动态添加/删除Stage
  • 提供清晰的接口定义
  • 支持自定义配置

5.2 性能优化

让我们实现一个性能监控组件:

package pipeline

import (
    "fmt"
    "sort"
    "strings"
    "sync"
    "sync/atomic"
    "time"
)

// PipelineMonitor collects pipeline-wide and per-stage processing statistics.
// It is safe for concurrent use. The counters are updated with sync/atomic,
// and mu guards startTime and the stageMetrics map. The map itself needs the
// lock; the counters inside each entry are atomic.
type PipelineMonitor struct {
    // The atomically accessed int64 counters must come first. 64-bit atomic
    // operations need 64-bit alignment on 32-bit platforms (386, ARM, 32-bit
    // MIPS), and only the first word of an allocated struct is guaranteed to
    // have it. Do not move startTime or mu in front of these fields.
    processCount int64
    errorCount   int64
    totalLatency int64 // nanoseconds
    maxLatency   int64 // nanoseconds

    mu           sync.RWMutex
    startTime    time.Time
    stageMetrics map[string]*StageMetrics
}

// StageMetrics holds the counters for a single stage. All fields are
// accessed atomically, and latencies are stored in nanoseconds.
type StageMetrics struct {
    processCount int64
    errorCount   int64
    totalLatency int64
    maxLatency   int64
}

// NewPipelineMonitor returns a monitor whose uptime starts now.
func NewPipelineMonitor() *PipelineMonitor {
    return &PipelineMonitor{
        startTime:    time.Now(),
        stageMetrics: make(map[string]*StageMetrics),
    }
}

// storeMaxInt64 raises *addr to v if v is larger, retrying the
// compare-and-swap when another goroutine updates *addr concurrently.
func storeMaxInt64(addr *int64, v int64) {
    for {
        cur := atomic.LoadInt64(addr)
        if cur >= v || atomic.CompareAndSwapInt64(addr, cur, v) {
            return
        }
    }
}

// stageFor returns the metrics entry for name, creating it if needed. It
// takes the cheap read lock first and upgrades to the write lock only when
// the entry has to be created. Two goroutines creating the same entry
// concurrently would otherwise hit Go's fatal "concurrent map writes" error.
func (m *PipelineMonitor) stageFor(name string) *StageMetrics {
    m.mu.RLock()
    s, ok := m.stageMetrics[name]
    m.mu.RUnlock()
    if ok {
        return s
    }

    m.mu.Lock()
    defer m.mu.Unlock()
    if m.stageMetrics == nil { // zero-value PipelineMonitor
        m.stageMetrics = make(map[string]*StageMetrics)
    }
    if s, ok = m.stageMetrics[name]; !ok {
        s = &StageMetrics{}
        m.stageMetrics[name] = s
    }
    return s
}

// RecordProcessing records one processed item for stageName. It updates the
// pipeline-wide and per-stage counts, total and max latency, and error count
// (the error count only when err is non-nil).
func (m *PipelineMonitor) RecordProcessing(stageName string, latency time.Duration, err error) {
    ns := int64(latency)
    failed := err != nil

    atomic.AddInt64(&m.processCount, 1)
    atomic.AddInt64(&m.totalLatency, ns)
    storeMaxInt64(&m.maxLatency, ns)
    if failed {
        atomic.AddInt64(&m.errorCount, 1)
    }

    s := m.stageFor(stageName)
    atomic.AddInt64(&s.processCount, 1)
    atomic.AddInt64(&s.totalLatency, ns)
    storeMaxInt64(&s.maxLatency, ns)
    if failed {
        atomic.AddInt64(&s.errorCount, 1)
    }
}

// GetMetrics renders a human-readable report of the pipeline-wide metrics,
// followed by one section per stage sorted by stage name. Averages, rates and
// throughput are reported as 0 when there is nothing to divide by.
func (m *PipelineMonitor) GetMetrics() string {
    type namedStage struct {
        name    string
        metrics *StageMetrics
    }

    // Snapshot startTime and the stage map under the read lock. The counters
    // are read atomically afterwards.
    m.mu.RLock()
    uptime := time.Since(m.startTime)
    stages := make([]namedStage, 0, len(m.stageMetrics))
    for name, sm := range m.stageMetrics {
        stages = append(stages, namedStage{name: name, metrics: sm})
    }
    m.mu.RUnlock()

    // Map iteration order is random; sort so the report is stable.
    sort.Slice(stages, func(i, j int) bool { return stages[i].name < stages[j].name })

    mean := func(total time.Duration, n int64) time.Duration {
        if n <= 0 {
            return 0
        }
        return total / time.Duration(n)
    }

    processCount := atomic.LoadInt64(&m.processCount)
    errorCount := atomic.LoadInt64(&m.errorCount)
    totalLatency := time.Duration(atomic.LoadInt64(&m.totalLatency))
    maxLatency := time.Duration(atomic.LoadInt64(&m.maxLatency))

    var throughput float64
    if secs := uptime.Seconds(); secs > 0 {
        throughput = float64(processCount) / secs
    }

    var b strings.Builder
    fmt.Fprintf(&b,
        "Pipeline Metrics:\n"+
            "Uptime: %v\n"+
            "Processed: %d\n"+
            "Errors: %d\n"+
            "Average Latency: %v\n"+
            "Max Latency: %v\n"+
            "Throughput: %.2f/sec\n",
        uptime,
        processCount,
        errorCount,
        mean(totalLatency, processCount),
        maxLatency,
        throughput,
    )

    b.WriteString("\nStage Metrics:\n")
    for _, s := range stages {
        stageProcessCount := atomic.LoadInt64(&s.metrics.processCount)
        stageErrorCount := atomic.LoadInt64(&s.metrics.errorCount)
        stageTotalLatency := time.Duration(atomic.LoadInt64(&s.metrics.totalLatency))
        stageMaxLatency := time.Duration(atomic.LoadInt64(&s.metrics.maxLatency))

        // A freshly created entry may not have its count incremented yet.
        // Avoid a 0/0 NaN in that case.
        var errorRate float64
        if stageProcessCount > 0 {
            errorRate = float64(stageErrorCount) * 100 / float64(stageProcessCount)
        }

        fmt.Fprintf(&b,
            "  %s:\n"+
                "    Processed: %d\n"+
                "    Errors: %d\n"+
                "    Average Latency: %v\n"+
                "    Max Latency: %v\n"+
                "    Error Rate: %.2f%%\n",
            s.name,
            stageProcessCount,
            stageErrorCount,
            mean(stageTotalLatency, stageProcessCount),
            stageMaxLatency,
            errorRate,
        )
    }

    return b.String()
}

// Reset zeroes all counters, drops every stage entry and restarts the uptime
// clock. Items recorded concurrently with Reset may land in the old or the
// new period.
func (m *PipelineMonitor) Reset() {
    m.mu.Lock()
    defer m.mu.Unlock()

    m.startTime = time.Now()
    atomic.StoreInt64(&m.processCount, 0)
    atomic.StoreInt64(&m.errorCount, 0)
    atomic.StoreInt64(&m.totalLatency, 0)
    atomic.StoreInt64(&m.maxLatency, 0)
    m.stageMetrics = make(map[string]*StageMetrics)
}

5.3 Pipeline使用流程图

(原文此处为Pipeline使用流程图,图片在转载时已丢失。)


原文地址:https://blog.csdn.net/weixin_40780178/article/details/144049684

免责声明:本站文章内容转载自网络资源,如侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!