自学内容网 自学内容网

Go 中的泛型,日常如何使用

泛型从 go 的 1.18 开始支持

什么是泛型编程

在泛型出现之前,如果需要计算两数之和,可能会这样写:

func Add(a, b int) int {
  returb a + b
}

这个很简单,但是只能两个参数都是 int 类型的时候才能调用

如果想要计算两个浮点数的和,就需要再定义一个函数

func AddFloat32(a, b float32) float32 {
  return a + b
}

如果需要计算两个字符串的和,就需要再定义一个函数,这样太过麻烦

在一个函数中,行参只是类似占位符的东西,只有调用函数传入实参之后才有具体的值

如果把行参、实参的概念推广一下,给变量的类型也引入类似行参实参的概念,那就可以接收各种类型的值

函数就类似于:

func Add(a, b T) T {
  return a + b
}

在这段代码中,T 可以称为 类型行参(type parameter),它不是具体的类型,在定义函数时参数的类型不确定,在传入时才确定,传入参数的具体类型称为 类型实参(type argument)

这样通过 类型行参类型实参 进行编码的方式就称为 泛型编程

Go 的泛型

除了上面提到的类型行参和类型实参,Go 还引入了其他概念:

  • 类型形参 (Type parameter)

  • 类型实参(Type argument)

  • 类型形参列表( Type parameter list)

  • 类型约束(Type constraint)

  • 实例化(Instantiations)

  • 泛型类型(Generic type)

  • 泛型接收器(Generic receiver)

  • 泛型函数(Generic function)

一个一个说

类型行参、类型实参、类型约束和泛型类型

假如现在要定义一个切片,可以容纳 int、float32、string 等多种类型,不使用泛型常规做法是为每种类型各自定义一个切片

使用泛型,就可以这样定义:

type AllTypeSlice[T int | float32 | string] []T

在这行代码中:

  • T 就是上面介绍的 类型参数,在定义 slice 时 T 的类型不确定,类似于一个占位符

  • int|float|string 称为 类型约束,中间的 | 就是告诉编译器只能接收这几种类型中的一种

  • 中括号[] 中的 T int|float32|float64 这一整串定义了所有的类型实参(这里只有 T 这一个类型行参),称为 类型行参列表

泛型类型不能拿来使用,必须传入类型实参,能确定类型后才能使用,这个过程称为 实例化

func main() {
    type AllTypeSlice[T int | float32 | string] []T
​
    intSlice := AllTypeSlice[int]{}
    fmt.Printf("%T\n", intSlice) // main.AllTypeSlice[int]
​
    floatSlice := AllTypeSlice[float32]{}
    fmt.Printf("%T\n", floatSlice) // main.AllTypeSlice[float32]
}

类型行参的个数可以有多个,比如:

type AllTypeMap[KEY int | string, VALUE float32 | float64] map[KEY]VALUE

这样就定义了一个泛型 map,key 和 value 都可以是多种

func main() {
    type AllTypeMap[KEY int | string, VALUE float32 | float64] map[KEY]VALUE
​
    var a AllTypeMap[string, float64] = map[string]float64{
        "jack_score": 9.6,
        "bob_score":  8.4,
    }
    fmt.Printf("%T\n", a) // main.AllTypeMap[string,float64]
}
其他的泛型类型

所有类型定义都可以使用类型行参,包括结构体、接口、通道

// 泛型结构体
type AllTypeStruct[T int | string] struct {
  Name string
  Data T
}
​
// 泛型接口
type PrintData[T int | float32] interface {
  Print(data T)
}
​
// 泛型通道
type AllTypeChan[T int | string] chan T
类型行参的嵌套

类型行参是可以嵌套使用的,比如:

type AllTypeStruct[T int | float32, S []T] struct {
    Data     S
    MaxValue T
    MinValue T
}

使用:

s1 := AllTypeStruct[int, []int]{}
s2 := AllTypeStruct[int, []float32]{} // 报错 类型需要一致

即使 T 可以从 int 和 float32 中选,但是传入时类型已经确定了,所以 T 的类型和 []T 中 T 的类型要保持一致

泛型的使用

泛型一般用来实现一些不需要关注具体类型就可以使用的通用方法

给出一个比较常用的场景:Gorm 中使用泛型写一些基础方法,不同类型调用这些方法,会使用对应的 model,通过相同的逻辑查不同的表,而不需要再分别写各自的方法

常见的一些通用方法:

func IsErrRecordNotFound(err error) bool {
    return errors.Is(err, gorm.ErrRecordNotFound)
}
​
type baseRepo[T any] struct {
    db *gorm.DB
}
​
func NewBaseRepo[T any](db *gorm.DB) baseRepo[T] {
    return baseRepo[T]{db: db}
}
​
func (r *baseRepo[T]) GetDB() *gorm.DB {
    return r.db
}
​
// GetByID 通过 id 查单条记录
func (r *baseRepo[T]) GetByID(id int, preloads ...string) (*T, error) {
    var result T
    db := r.db
    for _, preload := range preloads {
        db = db.Preload(preload)
    }
    if err := db.First(&result, id).Error; err != nil {
        return nil, err
    }
    return &result, nil
}
​
// GetByIds 通过 ids 用 IN 查结果集
func (r *baseRepo[T]) GetByIds(ids []int, preloads ...string) (list []*T, err error) {
    db := r.db
    for _, preload := range preloads {
        db = db.Preload(preload)
    }
    err = db.Unscoped().Where("id IN ?", ids).Find(&list).Error
    return
}
​
// GetFirst 根据条件查第一条记录 传入对应类型结构体
func (r *baseRepo[T]) GetFirst(cond T, preloads ...string) (*T, error) {
    var result T
    db := r.db
    for _, preload := range preloads {
        db = db.Preload(preload)
    }
    if err := db.First(&result, cond).Error; err != nil {
        if IsErrRecordNotFound(err) {
            return nil, nil
        }
        return nil, err
    }
    return &result, nil
}
​
// GetList 通过条件查所有记录
func (r *baseRepo[T]) GetList(cond T, preloads ...string) ([]*T, error) {
    var list []*T
    db := r.db
    for _, preload := range preloads {
        db = db.Preload(preload)
    }
    if err := db.Find(&list, cond).Error; err != nil {
        return nil, err
    }
    return list, nil
}
​
// GetListWithOrder 根据条件查询所有记录并排序
func (r *baseRepo[T]) GetListWithOrder(cond T, order string, preloads ...string) ([]*T, error) {
    var list []*T
    db := r.db
    for _, preload := range preloads {
        db = db.Preload(preload)
    }
    if err := db.Order(order).Find(&list, cond).Error; err != nil {
        return nil, err
    }
    return list, nil
}
​
// GetPage 分页查询记录
func (r *baseRepo[T]) GetPage(cond *T, order string, pageNo, pageSize int, preloads ...string) ([]*T, error) {
    db := r.db
    for _, preload := range preloads {
        db = db.Preload(preload)
    }
​
    offset := (pageNo - 1) * pageSize
    limit := pageSize
    db = db.Order(order).Offset(offset).Limit(limit)
    if cond != nil {
        db = db.Where(cond)
    }
​
    var list []*T
    if err := db.Find(&list).Error; err != nil {
        return nil, err
    }
    return list, nil
}
​
// LikeWithOrder 模糊查询并排序
func (r *baseRepo[T]) LikeWithOrder(columns []string, keyword, order string, preloads ...string) ([]*T, error) {
    var list []*T
    db := r.db
    for _, preload := range preloads {
        db = db.Preload(preload)
    }
    like := "%" + keyword + "%"
    for _, column := range columns {
        query := fmt.Sprintf("%s like ?", column)
        db = db.Or(query, like)
    }
    if err := db.Order(order).Find(&list).Error; err != nil {
        return nil, err
    }
    return list, nil
}
​
// GetAll 查询所有记录
func (r *baseRepo[T]) GetAll(preloads ...string) ([]*T, error) {
    var list []*T
    db := r.db
    for _, preload := range preloads {
        db = db.Preload(preload)
    }
    if err := db.Find(&list).Error; err != nil {
        return nil, err
    }
    return list, nil
}
​
// GetIds 通过条件查 ids
func (r *baseRepo[T]) GetIds(cond T) ([]int, error) {
    var ids []int
    model := new(T)
    if err := r.db.Model(model).Where(cond).Pluck("id", &ids).Error; err != nil {
        return nil, err
    }
    return ids, nil
}
​
// UpdateById 更新单个属性
func (r *baseRepo[T]) UpdateById(id int, column string, value any) error {
    m := new(T)
    return r.db.Model(m).Where("id = ?", id).Update(column, value).Error
}
​
// UpdatesById 更新多个属性值,可以传 map 或者结构体
func (r *baseRepo[T]) UpdatesById(id int, updateInfo interface{}) error {
    m := new(T)
    return r.db.Model(m).Where("id = ?", id).Updates(updateInfo).Error
}
​
// DeleteByID 删除一条记录
func (r *baseRepo[T]) DeleteByID(id int, force ...bool) error {
    m := new(T)
    session := r.db.Model(m)
    if len(force) > 0 && force[0] == true {
        session = session.Unscoped()
    }
    return session.Delete(m, id).Error
}
​
// DeleteBatch 删除多条记录
func (r *baseRepo[T]) DeleteBatch(ids []int, force ...bool) error {
    m := new(T)
    session := r.db.Where("id IN ?", ids)
    if len(force) > 0 && force[0] == true {
        session = session.Unscoped()
    }
    return session.Delete(m).Error
}
​
// DeleteWith 通过条件删除记录
func (r *baseRepo[T]) DeleteWith(cond T, force ...bool) error {
    m := new(T)
    session := r.db.Where(cond)
    if len(force) > 0 && force[0] == true {
        session = session.Unscoped()
    }
    return session.Delete(m).Error
}
​
// Count 查询记录条目数
func (r *baseRepo[T]) Count(cond T) (int, error) {
    var num int64
    m := new(T)
    err := r.db.Model(m).Where(cond).Count(&num).Error
    return int(num), err
}

这些方法都是一些逻辑比较简单的通用方法,在使用时,只需要通过结构体嵌套的方式注入 baseRepo ,在查询时就可以直接使用,很方便

结语

对于泛型,我认为只需要掌握基础的使用方法,一些细节和高级用法需要用到时再去查询比较合适,重点在于日常使用

参考(真的很详细,推荐阅读):后端 - Go 1.18 泛型全面讲解:一篇讲清泛型的全部 - 个人文章 - SegmentFault 思否


原文地址:https://blog.csdn.net/Shoulen/article/details/143642756

免责声明:本站文章内容转载自网络资源,如本站内容侵犯了原著者的合法权益,可联系本站删除。更多内容请关注自学内容网(zxcms.com)!