thread-master/backend/internal/library/mongo/ensure_index.go

91 lines
2.3 KiB
Go

package mongo
import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"strings"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
)
// Server error codes MongoDB reports when a createIndexes request clashes with
// an index that already exists on the collection.
const (
	// IndexOptionsConflict: an existing index conflicts on its options.
	indexOptionsConflictCode = 85
	// IndexKeySpecsConflict: an existing index conflicts on its key spec.
	indexKeySpecsConflictCode = 86
)
// EnsureIndexes creates each of the requested index models on coll, in order.
//
// It recovers from conflicts caused by an index that an earlier schema version
// created with the same name but different options (e.g. adding a
// partialFilterExpression during the persona -> brand migration). When
// CreateOne fails with IndexOptionsConflict (85) or IndexKeySpecsConflict (86),
// the existing index is dropped by the name the model would be created under,
// and the model is created again. This keeps startup from failing on
// environments that still hold the legacy index.
//
// A nil coll means there is nothing to do, and nil is returned. Processing
// stops at the first error, and models after the failing one are not
// attempted. The following errors stop processing:
//   - a non-conflict create error, which is returned unchanged;
//   - a conflict for a model whose index name cannot be derived, which returns
//     the original conflict unchanged;
//   - a failed drop, which returns both the drop error and the conflict, with
//     the conflict wrapped via %w;
//   - a failed re-create, which returns that error wrapped via %w.
func EnsureIndexes(ctx context.Context, coll *mongo.Collection, models []mongo.IndexModel) error {
	if coll == nil {
		return nil
	}
	indexes := coll.Indexes()
	for _, model := range models {
		_, err := indexes.CreateOne(ctx, model)
		if err == nil {
			continue
		}
		if !isIndexConflict(err) {
			return err
		}
		name := indexName(model)
		if name == "" {
			// Without a name the stale index cannot be targeted; surface the
			// original conflict rather than guessing.
			return err
		}
		if _, dropErr := indexes.DropOne(ctx, name); dropErr != nil {
			// Report why recovery failed as well as the conflict that triggered
			// it. The conflict stays wrapped so callers inspecting the server
			// error (errors.As / errors.Is) keep working.
			return fmt.Errorf("drop conflicting index %q: %v: %w", name, dropErr, err)
		}
		if _, retryErr := indexes.CreateOne(ctx, model); retryErr != nil {
			return fmt.Errorf("recreate index %q after dropping conflicting index: %w", name, retryErr)
		}
	}
	return nil
}
// isIndexConflict reports whether err is, or wraps, a MongoDB server error
// carrying either indexOptionsConflictCode or indexKeySpecsConflictCode.
// A nil error, or one that is not a mongo.ServerError, yields false.
func isIndexConflict(err error) bool {
	var se mongo.ServerError
	if !errors.As(err, &se) {
		return false
	}
	for _, code := range [...]int{indexOptionsConflictCode, indexKeySpecsConflictCode} {
		if se.HasErrorCode(code) {
			return true
		}
	}
	return false
}
// indexName returns the name an index model will be created under, so that a
// conflicting index with that name can be dropped.
//
// An explicit Options.Name takes precedence. Otherwise the name is derived the
// way MongoDB forms default index names: each key followed by its value token,
// all joined by underscores (e.g. {a: 1, b: -1} -> "a_1_b_-1"). Only ordered
// bson.D key specs are supported, because the name depends on key order.
//
// It returns "" when no name can be derived. That happens when Keys is not a
// bson.D, when it is empty, or when any key value has no token (see
// indexValueToken). Returning "" in that case, rather than a partial name such
// as "a_", lets the caller fall back to reporting the original error. Otherwise
// the caller would try to drop an index that does not exist.
func indexName(model mongo.IndexModel) string {
	if model.Options != nil && model.Options.Name != nil {
		return *model.Options.Name
	}
	keys, ok := model.Keys.(bson.D)
	if !ok || len(keys) == 0 {
		return ""
	}
	parts := make([]string, 0, len(keys)*2)
	for _, e := range keys {
		token := indexValueToken(e.Value)
		if token == "" {
			// Unsupported value type: any name built from here would be wrong.
			return ""
		}
		parts = append(parts, e.Key, token)
	}
	return strings.Join(parts, "_")
}
// indexValueToken renders one key-spec value as the token used in a default
// index name. It handles the following types:
//   - int, int32 and int64 are written in base 10;
//   - float64 is converted to int64 (truncating toward zero) and written in
//     base 10;
//   - strings (e.g. "text", "2dsphere") are returned verbatim.
//
// Any other type yields "".
func indexValueToken(v any) string {
	var n int64
	switch val := v.(type) {
	case string:
		return val
	case int:
		n = int64(val)
	case int32:
		n = int64(val)
	case int64:
		n = val
	case float64:
		n = int64(val)
	default:
		return ""
	}
	return strconv.FormatInt(n, 10)
}