Skip to content

Commit c7a934e

Browse files
authored
Merge pull request #3341 from openimsdk/cherry-pick-56c5c1f
fix: delete token by correct platformID && feat: adminToken can be re… [Created by @icey-yu from #3313]
2 parents 8f61586 + 78aaf6a commit c7a934e

File tree

7 files changed

+492
-45
lines changed

7 files changed

+492
-45
lines changed

internal/rpc/auth/auth.go

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,15 +140,17 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
140140
if err != nil {
141141
return nil, err
142142
}
143-
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
144-
if isAdmin {
145-
return claims, nil
146-
}
147143
m, err := s.authDatabase.GetTokensWithoutError(ctx, claims.UserID, claims.PlatformID)
148144
if err != nil {
149145
return nil, err
150146
}
151147
if len(m) == 0 {
148+
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
149+
if isAdmin {
150+
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
151+
return claims, nil
152+
}
153+
}
152154
return nil, servererrs.ErrTokenNotExist.Wrap()
153155
}
154156
if v, ok := m[tokensString]; ok {
@@ -160,6 +162,13 @@ func (s *authServer) parseToken(ctx context.Context, tokensString string) (claim
160162
default:
161163
return nil, errs.Wrap(errs.ErrTokenUnknown)
162164
}
165+
} else {
166+
isAdmin := authverify.IsManagerUserID(claims.UserID, s.config.Share.IMAdminUserID)
167+
if isAdmin {
168+
if err = s.authDatabase.GetTemporaryTokensWithoutError(ctx, claims.UserID, claims.PlatformID, tokensString); err == nil {
169+
return claims, nil
170+
}
171+
}
163172
}
164173
return nil, servererrs.ErrTokenNotExist.Wrap()
165174
}

pkg/common/storage/cache/cachekey/token.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package cachekey
22

33
import (
4-
"github.com/openimsdk/protocol/constant"
54
"strings"
5+
6+
"github.com/openimsdk/protocol/constant"
67
)
78

89
const (
@@ -13,6 +14,10 @@ func GetTokenKey(userID string, platformID int) string {
1314
return UidPidToken + userID + ":" + constant.PlatformIDToName(platformID)
1415
}
1516

17+
func GetTemporaryTokenKey(userID string, platformID int, token string) string {
18+
return UidPidToken + ":TEMPORARY:" + userID + ":" + constant.PlatformIDToName(platformID) + ":" + token
19+
}
20+
1621
func GetAllPlatformTokenKey(userID string) []string {
1722
res := make([]string, len(constant.PlatformID2Name))
1823
for k := range constant.PlatformID2Name {
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
package mcache
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strconv"
7+
"strings"
8+
"time"
9+
10+
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
11+
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
12+
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/database"
13+
"github.com/openimsdk/tools/errs"
14+
"github.com/openimsdk/tools/log"
15+
)
16+
17+
func NewTokenCacheModel(cache database.Cache, accessExpire int64) cache.TokenModel {
18+
c := &tokenCache{cache: cache}
19+
c.accessExpire = c.getExpireTime(accessExpire)
20+
return c
21+
}
22+
23+
type tokenCache struct {
24+
cache database.Cache
25+
accessExpire time.Duration
26+
}
27+
28+
func (x *tokenCache) getTokenKey(userID string, platformID int, token string) string {
29+
return cachekey.GetTokenKey(userID, platformID) + ":" + token
30+
}
31+
32+
func (x *tokenCache) SetTokenFlag(ctx context.Context, userID string, platformID int, token string, flag int) error {
33+
return x.cache.Set(ctx, x.getTokenKey(userID, platformID, token), strconv.Itoa(flag), x.accessExpire)
34+
}
35+
36+
// SetTokenFlagEx set token and flag with expire time
37+
func (x *tokenCache) SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error {
38+
return x.SetTokenFlag(ctx, userID, platformID, token, flag)
39+
}
40+
41+
func (x *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error) {
42+
prefix := x.getTokenKey(userID, platformID, "")
43+
m, err := x.cache.Prefix(ctx, prefix)
44+
if err != nil {
45+
return nil, errs.Wrap(err)
46+
}
47+
mm := make(map[string]int)
48+
for k, v := range m {
49+
state, err := strconv.Atoi(v)
50+
if err != nil {
51+
log.ZError(ctx, "token value is not int", err, "value", v, "userID", userID, "platformID", platformID)
52+
continue
53+
}
54+
mm[strings.TrimPrefix(k, prefix)] = state
55+
}
56+
return mm, nil
57+
}
58+
59+
func (x *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
60+
key := cachekey.GetTemporaryTokenKey(userID, platformID, token)
61+
if _, err := x.cache.Get(ctx, []string{key}); err != nil {
62+
return err
63+
}
64+
return nil
65+
}
66+
67+
func (x *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
68+
prefix := cachekey.UidPidToken + userID + ":"
69+
tokens, err := x.cache.Prefix(ctx, prefix)
70+
if err != nil {
71+
return nil, err
72+
}
73+
res := make(map[int]map[string]int)
74+
for key, flagStr := range tokens {
75+
flag, err := strconv.Atoi(flagStr)
76+
if err != nil {
77+
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
78+
continue
79+
}
80+
arr := strings.SplitN(strings.TrimPrefix(key, prefix), ":", 2)
81+
if len(arr) != 2 {
82+
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
83+
continue
84+
}
85+
platformID, err := strconv.Atoi(arr[0])
86+
if err != nil {
87+
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
88+
continue
89+
}
90+
token := arr[1]
91+
if token == "" {
92+
log.ZError(ctx, "token value is not int", err, "key", key, "value", flagStr, "userID", userID)
93+
continue
94+
}
95+
tk, ok := res[platformID]
96+
if !ok {
97+
tk = make(map[string]int)
98+
res[platformID] = tk
99+
}
100+
tk[token] = flag
101+
}
102+
return res, nil
103+
}
104+
105+
func (x *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error {
106+
for token, flag := range m {
107+
err := x.SetTokenFlag(ctx, userID, platformID, token, flag)
108+
if err != nil {
109+
return err
110+
}
111+
}
112+
return nil
113+
}
114+
115+
func (x *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error {
116+
for prefix, tokenFlag := range tokens {
117+
for token, flag := range tokenFlag {
118+
flagStr := fmt.Sprintf("%v", flag)
119+
if err := x.cache.Set(ctx, prefix+":"+token, flagStr, x.accessExpire); err != nil {
120+
return err
121+
}
122+
}
123+
}
124+
return nil
125+
}
126+
127+
func (x *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error {
128+
keys := make([]string, 0, len(fields))
129+
for _, token := range fields {
130+
keys = append(keys, x.getTokenKey(userID, platformID, token))
131+
}
132+
return x.cache.Del(ctx, keys)
133+
}
134+
135+
func (x *tokenCache) getExpireTime(t int64) time.Duration {
136+
return time.Hour * 24 * time.Duration(t)
137+
}
138+
139+
func (x *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
140+
keys := make([]string, 0, len(tokens))
141+
for platformID, ts := range tokens {
142+
for _, t := range ts {
143+
keys = append(keys, x.getTokenKey(userID, platformID, t))
144+
}
145+
}
146+
return x.cache.Del(ctx, keys)
147+
}
148+
149+
func (x *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
150+
keys := make([]string, 0, len(fields))
151+
for _, f := range fields {
152+
keys = append(keys, x.getTokenKey(userID, platformID, f))
153+
}
154+
if err := x.cache.Del(ctx, keys); err != nil {
155+
return err
156+
}
157+
158+
for _, f := range fields {
159+
k := cachekey.GetTemporaryTokenKey(userID, platformID, f)
160+
if err := x.cache.Set(ctx, k, "", time.Minute*5); err != nil {
161+
return errs.Wrap(err)
162+
}
163+
}
164+
165+
return nil
166+
}

pkg/common/storage/cache/redis/token.go

Lines changed: 66 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache"
1010
"github.com/openimsdk/open-im-server/v3/pkg/common/storage/cache/cachekey"
1111
"github.com/openimsdk/tools/errs"
12+
"github.com/openimsdk/tools/utils/datautil"
1213
"github.com/redis/go-redis/v9"
1314
)
1415

@@ -55,6 +56,14 @@ func (c *tokenCache) GetTokensWithoutError(ctx context.Context, userID string, p
5556
return mm, nil
5657
}
5758

59+
func (c *tokenCache) HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error {
60+
err := c.rdb.Get(ctx, cachekey.GetTemporaryTokenKey(userID, platformID, token)).Err()
61+
if err != nil {
62+
return errs.Wrap(err)
63+
}
64+
return nil
65+
}
66+
5867
func (c *tokenCache) GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error) {
5968
var (
6069
res = make(map[int]map[string]int)
@@ -101,13 +110,19 @@ func (c *tokenCache) SetTokenMapByUidPid(ctx context.Context, userID string, pla
101110
}
102111

103112
func (c *tokenCache) BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error {
104-
pipe := c.rdb.Pipeline()
105-
for k, v := range tokens {
106-
pipe.HSet(ctx, k, v)
107-
}
108-
_, err := pipe.Exec(ctx)
109-
if err != nil {
110-
return errs.Wrap(err)
113+
keys := datautil.Keys(tokens)
114+
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
115+
pipe := c.rdb.Pipeline()
116+
for k, v := range tokens {
117+
pipe.HSet(ctx, k, v)
118+
}
119+
_, err := pipe.Exec(ctx)
120+
if err != nil {
121+
return errs.Wrap(err)
122+
}
123+
return nil
124+
}); err != nil {
125+
return err
111126
}
112127
return nil
113128
}
@@ -119,3 +134,47 @@ func (c *tokenCache) DeleteTokenByUidPid(ctx context.Context, userID string, pla
119134
func (c *tokenCache) getExpireTime(t int64) time.Duration {
120135
return time.Hour * 24 * time.Duration(t)
121136
}
137+
138+
// DeleteTokenByTokenMap tokens key is platformID, value is token slice
139+
func (c *tokenCache) DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error {
140+
var (
141+
keys = make([]string, 0, len(tokens))
142+
keyMap = make(map[string][]string)
143+
)
144+
for k, v := range tokens {
145+
k1 := cachekey.GetTokenKey(userID, k)
146+
keys = append(keys, k1)
147+
keyMap[k1] = v
148+
}
149+
150+
if err := ProcessKeysBySlot(ctx, c.rdb, keys, func(ctx context.Context, slot int64, keys []string) error {
151+
pipe := c.rdb.Pipeline()
152+
for k, v := range tokens {
153+
pipe.HDel(ctx, cachekey.GetTokenKey(userID, k), v...)
154+
}
155+
_, err := pipe.Exec(ctx)
156+
if err != nil {
157+
return errs.Wrap(err)
158+
}
159+
return nil
160+
}); err != nil {
161+
return err
162+
}
163+
164+
return nil
165+
}
166+
167+
func (c *tokenCache) DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error {
168+
key := cachekey.GetTokenKey(userID, platformID)
169+
if err := c.rdb.HDel(ctx, key, fields...).Err(); err != nil {
170+
return errs.Wrap(err)
171+
}
172+
for _, f := range fields {
173+
k := cachekey.GetTemporaryTokenKey(userID, platformID, f)
174+
if err := c.rdb.Set(ctx, k, "", time.Minute*5).Err(); err != nil {
175+
return errs.Wrap(err)
176+
}
177+
}
178+
179+
return nil
180+
}

pkg/common/storage/cache/token.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,11 @@ type TokenModel interface {
99
// SetTokenFlagEx set token and flag with expire time
1010
SetTokenFlagEx(ctx context.Context, userID string, platformID int, token string, flag int) error
1111
GetTokensWithoutError(ctx context.Context, userID string, platformID int) (map[string]int, error)
12+
HasTemporaryToken(ctx context.Context, userID string, platformID int, token string) error
1213
GetAllTokensWithoutError(ctx context.Context, userID string) (map[int]map[string]int, error)
1314
SetTokenMapByUidPid(ctx context.Context, userID string, platformID int, m map[string]int) error
1415
BatchSetTokenMapByUidPid(ctx context.Context, tokens map[string]map[string]any) error
1516
DeleteTokenByUidPid(ctx context.Context, userID string, platformID int, fields []string) error
17+
DeleteTokenByTokenMap(ctx context.Context, userID string, tokens map[int][]string) error
18+
DeleteAndSetTemporary(ctx context.Context, userID string, platformID int, fields []string) error
1619
}

0 commit comments

Comments
 (0)