package mongo import ( "context" "errors" "strconv" "strings" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" ) // MongoDB server error codes for index conflicts. const ( indexOptionsConflictCode = 85 // IndexOptionsConflict indexKeySpecsConflictCode = 86 // IndexKeySpecsConflict ) // EnsureIndexes creates the requested indexes, recovering from conflicts caused // by indexes that an earlier schema version created with the same name but // different options (e.g. adding a partialFilterExpression during the // persona -> brand migration). On a conflict it drops the stale index by its // generated name and recreates it with the requested options, so startup does // not panic on environments that still hold the legacy index. func EnsureIndexes(ctx context.Context, coll *mongo.Collection, models []mongo.IndexModel) error { if coll == nil { return nil } for _, model := range models { if _, err := coll.Indexes().CreateOne(ctx, model); err != nil { if !isIndexConflict(err) { return err } name := indexName(model) if name == "" { return err } if _, dropErr := coll.Indexes().DropOne(ctx, name); dropErr != nil { return err } if _, retryErr := coll.Indexes().CreateOne(ctx, model); retryErr != nil { return retryErr } } } return nil } func isIndexConflict(err error) bool { var serverErr mongo.ServerError if errors.As(err, &serverErr) { return serverErr.HasErrorCode(indexOptionsConflictCode) || serverErr.HasErrorCode(indexKeySpecsConflictCode) } return false } // indexName reproduces MongoDB's default index name (key_direction pairs joined // by underscores) so a conflicting index can be dropped by name. 
func indexName(model mongo.IndexModel) string {
	if model.Options != nil && model.Options.Name != nil {
		return *model.Options.Name
	}

	var keys bson.D
	switch k := model.Keys.(type) {
	case bson.D:
		keys = k
	case bson.M:
		// Map iteration order is random, so only a single-key map has a
		// deterministic name.
		if len(k) != 1 {
			return ""
		}
		for key, value := range k {
			keys = bson.D{{Key: key, Value: value}}
		}
	default:
		return ""
	}

	parts := make([]string, 0, len(keys)*2)
	for _, e := range keys {
		token := indexValueToken(e.Value)
		if token == "" {
			// An unrenderable value would produce a bogus name like "field_";
			// report that no name could be derived instead.
			return ""
		}
		parts = append(parts, e.Key, token)
	}
	return strings.Join(parts, "_")
}

// indexValueToken renders one key-spec value as it appears in a default index
// name. Integer directions become decimal digits, whole-number floats are
// rendered as integers, and strings (index types such as "text") are used
// verbatim. It returns "" for any value with no such rendering, including
// fractional floats.
func indexValueToken(v any) string {
	switch t := v.(type) {
	case int:
		return strconv.Itoa(t)
	case int32:
		return strconv.FormatInt(int64(t), 10)
	case int64:
		return strconv.FormatInt(t, 10)
	case float64:
		// Truncating e.g. 1.5 to "1" would name a different index. NaN and
		// out-of-range values also fail this round-trip check.
		if t != float64(int64(t)) {
			return ""
		}
		return strconv.FormatInt(int64(t), 10)
	case string:
		return t
	default:
		return ""
	}
}