129 lines
3.7 KiB
Go
129 lines
3.7 KiB
Go
|
|
package auth
|
||
|
|
|
||
|
|
import (
|
||
|
|
"context"
|
||
|
|
|
||
|
|
"gateway/internal/library/zitadel"
|
||
|
|
authmetaenum "gateway/internal/model/auth/domain/enum"
|
||
|
|
domauth "gateway/internal/model/auth/domain/usecase"
|
||
|
|
memberdom "gateway/internal/model/member/domain"
|
||
|
|
dommember "gateway/internal/model/member/domain/usecase"
|
||
|
|
"gateway/internal/svc"
|
||
|
|
"gateway/internal/types"
|
||
|
|
|
||
|
|
"github.com/zeromicro/go-zero/core/logx"
|
||
|
|
)
|
||
|
|
|
||
|
|
type RegisterSocialCallbackLogic struct {
|
||
|
|
logx.Logger
|
||
|
|
ctx context.Context
|
||
|
|
svcCtx *svc.ServiceContext
|
||
|
|
}
|
||
|
|
|
||
|
|
func NewRegisterSocialCallbackLogic(ctx context.Context, svcCtx *svc.ServiceContext) *RegisterSocialCallbackLogic {
|
||
|
|
return &RegisterSocialCallbackLogic{
|
||
|
|
Logger: logx.WithContext(ctx),
|
||
|
|
ctx: ctx,
|
||
|
|
svcCtx: svcCtx,
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
func (l *RegisterSocialCallbackLogic) RegisterSocialCallback(req *types.RegisterSocialCallbackReq) (*types.AuthTokenData, error) {
|
||
|
|
if l.svcCtx.Zitadel == nil || l.svcCtx.AuthRegistrationSession == nil {
|
||
|
|
return nil, errb.SysNotImplemented("social registration not configured")
|
||
|
|
}
|
||
|
|
if l.svcCtx.MemberProvisioning == nil || l.svcCtx.MemberProfile == nil {
|
||
|
|
return nil, errb.SysNotImplemented("member provisioning not configured")
|
||
|
|
}
|
||
|
|
if l.svcCtx.AuthToken == nil {
|
||
|
|
return nil, errb.SysNotImplemented("auth token not configured")
|
||
|
|
}
|
||
|
|
|
||
|
|
sessionID, err := parseRegisterOAuthState(req.State)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
session, err := l.svcCtx.AuthRegistrationSession.Get(l.ctx, sessionID)
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
defer func() {
|
||
|
|
if delErr := l.svcCtx.AuthRegistrationSession.Delete(l.ctx, sessionID); delErr != nil {
|
||
|
|
logx.WithContext(l.ctx).Errorf("register social callback: delete session: %v", delErr)
|
||
|
|
}
|
||
|
|
}()
|
||
|
|
|
||
|
|
tok, err := l.svcCtx.Zitadel.ExchangeAuthorizationCode(l.ctx, req.Code, session.RedirectURI)
|
||
|
|
if err != nil {
|
||
|
|
return nil, wrapZitadelErr(err)
|
||
|
|
}
|
||
|
|
var claims *zitadel.IDTokenClaims
|
||
|
|
if tok.IDToken != "" {
|
||
|
|
claims, err = l.svcCtx.Zitadel.VerifyIDToken(l.ctx, tok.IDToken)
|
||
|
|
} else {
|
||
|
|
claims, err = zitadelIdentityFromToken(l.ctx, l.svcCtx.Zitadel, tok)
|
||
|
|
}
|
||
|
|
if err != nil {
|
||
|
|
return nil, wrapZitadelErr(err)
|
||
|
|
}
|
||
|
|
|
||
|
|
if !claims.EmailVerified {
|
||
|
|
return nil, errb.AuthForbidden("social email is not verified")
|
||
|
|
}
|
||
|
|
|
||
|
|
isExisting := false
|
||
|
|
if _, err := l.svcCtx.MemberProfile.GetByZitadelUserID(l.ctx, session.TenantID, claims.Sub); err == nil {
|
||
|
|
isExisting = true
|
||
|
|
} else if !isMemberNotFound(err) {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
if isExisting && session.InviteNewUsersOnly {
|
||
|
|
return nil, errb.ResAlreadyExist("account already exists, please login").WithCause(memberdom.ErrDuplicateMember)
|
||
|
|
}
|
||
|
|
|
||
|
|
var inviteCodeID string
|
||
|
|
if l.svcCtx.Config.Member.Defaults().Registration.RequireInviteCode {
|
||
|
|
if l.svcCtx.AuthInvite == nil {
|
||
|
|
return nil, errb.SysNotImplemented("invite validation not configured")
|
||
|
|
}
|
||
|
|
consumed, err := l.svcCtx.AuthInvite.Consume(l.ctx, &domauth.ConsumeInviteRequest{
|
||
|
|
TenantID: session.TenantID,
|
||
|
|
Code: session.InviteCode,
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
inviteCodeID = consumed.ID
|
||
|
|
}
|
||
|
|
|
||
|
|
memberDTO, err := l.svcCtx.MemberProvisioning.EnsureFromOIDC(l.ctx, &dommember.EnsureFromOIDCRequest{
|
||
|
|
TenantID: session.TenantID,
|
||
|
|
ZitadelSub: claims.Sub,
|
||
|
|
Email: claims.Email,
|
||
|
|
EmailVerified: claims.EmailVerified,
|
||
|
|
DisplayName: claims.Name,
|
||
|
|
Locale: firstNonEmpty(session.Language, claims.Locale),
|
||
|
|
})
|
||
|
|
if err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
|
||
|
|
if !isExisting {
|
||
|
|
if err := recordRegistrationMeta(l.ctx, l.svcCtx, session.TenantID, memberDTO.UID, inviteCodeID, session.AcceptTermsVersion, session.MarketingOptIn, authmetaenum.RegistrationChannelGoogle); err != nil {
|
||
|
|
return nil, err
|
||
|
|
}
|
||
|
|
}
|
||
|
|
|
||
|
|
return issueAuthToken(l.ctx, l.svcCtx, session.TenantID, memberDTO.UID)
|
||
|
|
}
|
||
|
|
|
||
|
|
// firstNonEmpty scans its arguments left to right and returns the first one
// that is not the empty string. It returns "" when no argument is given or
// every argument is empty. Whitespace-only strings count as non-empty.
func firstNonEmpty(values ...string) string {
	for idx := 0; idx < len(values); idx++ {
		if candidate := values[idx]; len(candidate) > 0 {
			return candidate
		}
	}
	return ""
}
|