diff --git a/handler/admin/user.go b/handler/admin/user.go index 92817de..8d2ea5f 100644 --- a/handler/admin/user.go +++ b/handler/admin/user.go @@ -128,5 +128,5 @@ func (u *UserHandler) UpdateMFA(ctx *gin.Context) (interface{}, error) { if mfaParam.MFAType == nil { return nil, xerr.WithStatus(err, xerr.StatusBadRequest).WithMsg("parameter error") } - return u.UserService.UpdateMFA(ctx, mfaParam.MFAKey, *mfaParam.MFAType, mfaParam.AuthCode), nil + return nil, u.UserService.UpdateMFA(ctx, mfaParam.MFAKey, *mfaParam.MFAType, mfaParam.AuthCode) } diff --git a/service/impl/user.go b/service/impl/user.go index 96b0742..a38f6aa 100644 --- a/service/impl/user.go +++ b/service/impl/user.go @@ -84,18 +84,26 @@ func (u *userServiceImpl) Update(ctx context.Context, userParam *param.User) (*e } func (u *userServiceImpl) UpdateMFA(ctx context.Context, mfaKey string, mfaType consts.MFAType, mfaCode string) error { - if mfaType == consts.MFATFATotp { + user, err := MustGetAuthorizedUser(ctx) + if err != nil { + return err + } + + switch mfaType { + case consts.MFATFATotp: ok := u.TwoFactorMFAService.ValidateTFACode(mfaKey, mfaCode) if !ok { return xerr.WithStatus(nil, xerr.StatusBadRequest).WithMsg("Invalid Validation Code") } - } else if mfaType != consts.MFANone { + case consts.MFANone: + ok := u.TwoFactorMFAService.ValidateTFACode(user.MfaKey, mfaCode) + if !ok { + return xerr.WithStatus(nil, xerr.StatusBadRequest).WithMsg("Invalid Validation Code") + } + default: return xerr.WithMsg(nil, "Not supported authentication").WithStatus(xerr.StatusBadRequest) } - user, err := MustGetAuthorizedUser(ctx) - if err != nil { - return err - } + userDal := dal.GetQueryByCtx(ctx).User updateResult, err := userDal.WithContext(ctx).Where(userDal.ID.Eq(user.ID)).UpdateSimple( userDal.MfaKey.Value(mfaKey),