91 lines
2.3 KiB
Go
91 lines
2.3 KiB
Go
package mongo
|
|
|
|
import (
	"context"
	"errors"
	"fmt"
	"strconv"
	"strings"

	"go.mongodb.org/mongo-driver/bson"
	"go.mongodb.org/mongo-driver/mongo"
)
|
|
|
|
// Server error codes that MongoDB reports when a requested index clashes with
// an index that already exists on the collection.
const (
	// indexOptionsConflictCode is MongoDB's IndexOptionsConflict error code.
	indexOptionsConflictCode = 85

	// indexKeySpecsConflictCode is MongoDB's IndexKeySpecsConflict error code.
	indexKeySpecsConflictCode = 86
)
|
|
|
|
// EnsureIndexes creates the requested indexes, recovering from conflicts caused
// by indexes that an earlier schema version created with the same name but
// different options (e.g. adding a partialFilterExpression during the
// persona -> brand migration). On a conflict it drops the stale index by name
// (the explicit Options.Name, or the name derived from the key pattern) and
// recreates it with the requested options, so startup does not panic on
// environments that still hold the legacy index.
//
// A nil collection is treated as a no-op and returns nil. Models are processed
// in order and the first unrecoverable error is returned; indexes created
// before that point are left in place. Errors that are not index conflicts,
// and conflicts whose index name cannot be determined, are returned unchanged.
// If dropping or recreating the conflicting index fails, the returned error
// names the index and wraps the relevant failure.
//
// NOTE(review): the drop-and-recreate is not atomic; between DropOne and the
// retried CreateOne the collection has no index under that name.
func EnsureIndexes(ctx context.Context, coll *mongo.Collection, models []mongo.IndexModel) error {
	if coll == nil {
		return nil
	}

	indexes := coll.Indexes()
	for _, model := range models {
		_, err := indexes.CreateOne(ctx, model)
		if err == nil {
			continue
		}
		if !isIndexConflict(err) {
			return err
		}

		name := indexName(model)
		if name == "" {
			// Without a reliable name we cannot target the stale index safely.
			return err
		}

		if _, dropErr := indexes.DropOne(ctx, name); dropErr != nil {
			// Keep the original conflict in the chain (callers may inspect it
			// with errors.As) while still surfacing why recovery failed.
			return fmt.Errorf("index %q conflicts and dropping it failed (%v): %w", name, dropErr, err)
		}

		if _, retryErr := indexes.CreateOne(ctx, model); retryErr != nil {
			return fmt.Errorf("recreating index %q after dropping conflicting definition: %w", name, retryErr)
		}
	}

	return nil
}
|
|
|
|
// isIndexConflict reports whether err, or any error it wraps, is a MongoDB
// server error carrying the IndexOptionsConflict or IndexKeySpecsConflict code.
// Errors with no mongo.ServerError in their chain yield false.
func isIndexConflict(err error) bool {
	var se mongo.ServerError
	if !errors.As(err, &se) {
		return false
	}
	for _, code := range [...]int{indexOptionsConflictCode, indexKeySpecsConflictCode} {
		if se.HasErrorCode(code) {
			return true
		}
	}
	return false
}
|
|
|
|
// indexName returns the name of the index described by model so a conflicting
// index can be dropped by name. An explicit Options.Name wins; otherwise it
// reproduces the default name: key/value pairs joined by underscores
// (e.g. {a: 1, b: -1} -> "a_1_b_-1").
//
// It returns "" when the name cannot be determined: Keys is not a bson.D, the
// key pattern is empty, or a key's value has a type indexValueToken does not
// render. Returning "" rather than a partial name such as "a_" lets the caller
// give up with the original error instead of attempting a drop with a
// malformed name.
func indexName(model mongo.IndexModel) string {
	if opts := model.Options; opts != nil && opts.Name != nil {
		return *opts.Name
	}

	keys, ok := model.Keys.(bson.D)
	if !ok || len(keys) == 0 {
		return ""
	}

	parts := make([]string, 0, len(keys)*2)
	for _, e := range keys {
		token := indexValueToken(e.Value)
		if token == "" {
			// Unsupported value type: any name we built would be wrong.
			return ""
		}
		parts = append(parts, e.Key, token)
	}
	return strings.Join(parts, "_")
}
|
|
|
|
// indexValueToken renders one key-pattern value as it appears in an index
// name: signed and unsigned integers in base 10 (1 -> "1", -1 -> "-1"), floats
// in their shortest decimal form (1.0 -> "1", 1.5 -> "1.5"; previously
// non-integral floats were truncated), and strings such as "text" or
// "2dsphere" verbatim. Any other type yields "", which callers treat as
// "cannot name this index".
func indexValueToken(v any) string {
	switch t := v.(type) {
	case int:
		return strconv.Itoa(t)
	case int8:
		return strconv.FormatInt(int64(t), 10)
	case int16:
		return strconv.FormatInt(int64(t), 10)
	case int32:
		return strconv.FormatInt(int64(t), 10)
	case int64:
		return strconv.FormatInt(t, 10)
	case uint:
		return strconv.FormatUint(uint64(t), 10)
	case uint8:
		return strconv.FormatUint(uint64(t), 10)
	case uint16:
		return strconv.FormatUint(uint64(t), 10)
	case uint32:
		return strconv.FormatUint(uint64(t), 10)
	case uint64:
		return strconv.FormatUint(t, 10)
	case float32:
		return strconv.FormatFloat(float64(t), 'f', -1, 32)
	case float64:
		return strconv.FormatFloat(t, 'f', -1, 64)
	case string:
		return t
	default:
		return ""
	}
}
|