feat: add chat cache (#152)

This commit is contained in:
Buer
2024-04-16 10:36:18 +08:00
committed by GitHub
parent bbaa4eec4b
commit 3c7c13758b
23 changed files with 557 additions and 49 deletions

128
relay/util/cache.go Normal file
View File

@@ -0,0 +1,128 @@
package util
import (
"crypto/md5"
"encoding/hex"
"fmt"
"one-api/common"
"one-api/model"
"github.com/gin-gonic/gin"
)
type ChatCacheProps struct {
UserId int `json:"user_id"`
TokenId int `json:"token_id"`
ChannelID int `json:"channel_id"`
PromptTokens int `json:"prompt_tokens"`
CompletionTokens int `json:"completion_tokens"`
ModelName string `json:"model_name"`
Response string `json:"response"`
Hash string `json:"-"`
Cache bool `json:"-"`
Driver CacheDriver `json:"-"`
}
type CacheDriver interface {
Get(hash string, userId int) *ChatCacheProps
Set(hash string, props *ChatCacheProps, expire int64) error
}
func GetDebugList(userId int) ([]*ChatCacheProps, error) {
caches, err := model.GetChatCacheListByUserId(userId)
if err != nil {
return nil, err
}
var props []*ChatCacheProps
for _, cache := range caches {
prop, err := common.UnmarshalString[ChatCacheProps](cache.Data)
if err != nil {
continue
}
props = append(props, &prop)
}
return props, nil
}
func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
props := &ChatCacheProps{
Cache: false,
}
if !allow {
return props
}
if common.ChatCacheEnabled && c.GetBool("chat_cache") {
props.Cache = true
}
if common.RedisEnabled {
props.Driver = &ChatCacheRedis{}
} else {
props.Driver = &ChatCacheDB{}
}
props.UserId = c.GetInt("id")
props.TokenId = c.GetInt("token_id")
return props
}
func (p *ChatCacheProps) SetHash(request any) {
if !p.needCache() || request == nil {
return
}
p.hash(common.Marshal(request))
}
func (p *ChatCacheProps) SetResponse(response any) {
if !p.needCache() || response == nil {
return
}
if str, ok := response.(string); ok {
p.Response += str
return
}
p.Response = common.Marshal(response)
}
func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens int, modelName string) error {
if !p.needCache() || p.Response == "" {
return nil
}
p.ChannelID = channelId
p.PromptTokens = promptTokens
p.CompletionTokens = completionTokens
p.ModelName = modelName
return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute))
}
func (p *ChatCacheProps) GetCache() *ChatCacheProps {
if !p.needCache() {
return nil
}
return p.Driver.Get(p.getHash(), p.UserId)
}
func (p *ChatCacheProps) needCache() bool {
return common.ChatCacheEnabled && p.Cache
}
func (p *ChatCacheProps) getHash() string {
return p.Hash
}
func (p *ChatCacheProps) hash(request string) {
hash := md5.Sum([]byte(fmt.Sprintf("%d-%d-%s", p.UserId, p.TokenId, request)))
p.Hash = hex.EncodeToString(hash[:])
}

47
relay/util/cache_db.go Normal file
View File

@@ -0,0 +1,47 @@
package util
import (
"errors"
"one-api/common"
"one-api/model"
"time"
)
type ChatCacheDB struct{}
func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
cache, _ := model.GetChatCache(hash, userId)
if cache == nil {
return nil
}
props, err := common.UnmarshalString[ChatCacheProps](cache.Data)
if err != nil {
return nil
}
return &props
}
func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) error {
return SetCacheDB(hash, props, expire)
}
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
data := common.Marshal(props)
if data == "" {
return errors.New("marshal error")
}
expire = expire * 60
expire += time.Now().Unix()
cache := &model.ChatCache{
Hash: hash,
UserId: props.UserId,
Data: data,
Expiration: expire,
}
return cache.Insert()
}

44
relay/util/cache_redis.go Normal file
View File

@@ -0,0 +1,44 @@
package util
import (
"errors"
"fmt"
"one-api/common"
"time"
)
type ChatCacheRedis struct{}
var chatCacheKey = "chat_cache"
func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
cache, err := common.RedisGet(r.getKey(hash, userId))
if err != nil {
return nil
}
props, err := common.UnmarshalString[ChatCacheProps](cache)
if err != nil {
return nil
}
return &props
}
func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) error {
if !props.Cache {
return nil
}
data := common.Marshal(&props)
if data == "" {
return errors.New("marshal error")
}
return common.RedisSet(r.getKey(hash, props.UserId), data, time.Duration(expire)*time.Minute)
}
func (r *ChatCacheRedis) getKey(hash string, userId int) string {
return fmt.Sprintf("%s:%d:%s", chatCacheKey, userId, hash)
}