mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-16 21:23:44 +08:00
✨ feat: add chat cache (#152)
This commit is contained in:
128
relay/util/cache.go
Normal file
128
relay/util/cache.go
Normal file
@@ -0,0 +1,128 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"crypto/md5"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type ChatCacheProps struct {
|
||||
UserId int `json:"user_id"`
|
||||
TokenId int `json:"token_id"`
|
||||
ChannelID int `json:"channel_id"`
|
||||
PromptTokens int `json:"prompt_tokens"`
|
||||
CompletionTokens int `json:"completion_tokens"`
|
||||
ModelName string `json:"model_name"`
|
||||
Response string `json:"response"`
|
||||
|
||||
Hash string `json:"-"`
|
||||
Cache bool `json:"-"`
|
||||
Driver CacheDriver `json:"-"`
|
||||
}
|
||||
|
||||
type CacheDriver interface {
|
||||
Get(hash string, userId int) *ChatCacheProps
|
||||
Set(hash string, props *ChatCacheProps, expire int64) error
|
||||
}
|
||||
|
||||
func GetDebugList(userId int) ([]*ChatCacheProps, error) {
|
||||
caches, err := model.GetChatCacheListByUserId(userId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var props []*ChatCacheProps
|
||||
for _, cache := range caches {
|
||||
prop, err := common.UnmarshalString[ChatCacheProps](cache.Data)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
props = append(props, &prop)
|
||||
}
|
||||
|
||||
return props, nil
|
||||
}
|
||||
|
||||
func NewChatCacheProps(c *gin.Context, allow bool) *ChatCacheProps {
|
||||
props := &ChatCacheProps{
|
||||
Cache: false,
|
||||
}
|
||||
|
||||
if !allow {
|
||||
return props
|
||||
}
|
||||
|
||||
if common.ChatCacheEnabled && c.GetBool("chat_cache") {
|
||||
props.Cache = true
|
||||
}
|
||||
|
||||
if common.RedisEnabled {
|
||||
props.Driver = &ChatCacheRedis{}
|
||||
} else {
|
||||
props.Driver = &ChatCacheDB{}
|
||||
}
|
||||
|
||||
props.UserId = c.GetInt("id")
|
||||
props.TokenId = c.GetInt("token_id")
|
||||
|
||||
return props
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) SetHash(request any) {
|
||||
if !p.needCache() || request == nil {
|
||||
return
|
||||
}
|
||||
|
||||
p.hash(common.Marshal(request))
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) SetResponse(response any) {
|
||||
if !p.needCache() || response == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if str, ok := response.(string); ok {
|
||||
p.Response += str
|
||||
return
|
||||
}
|
||||
|
||||
p.Response = common.Marshal(response)
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) StoreCache(channelId, promptTokens, completionTokens int, modelName string) error {
|
||||
if !p.needCache() || p.Response == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
p.ChannelID = channelId
|
||||
p.PromptTokens = promptTokens
|
||||
p.CompletionTokens = completionTokens
|
||||
p.ModelName = modelName
|
||||
|
||||
return p.Driver.Set(p.getHash(), p, int64(common.ChatCacheExpireMinute))
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) GetCache() *ChatCacheProps {
|
||||
if !p.needCache() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return p.Driver.Get(p.getHash(), p.UserId)
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) needCache() bool {
|
||||
return common.ChatCacheEnabled && p.Cache
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) getHash() string {
|
||||
return p.Hash
|
||||
}
|
||||
|
||||
func (p *ChatCacheProps) hash(request string) {
|
||||
hash := md5.Sum([]byte(fmt.Sprintf("%d-%d-%s", p.UserId, p.TokenId, request)))
|
||||
p.Hash = hex.EncodeToString(hash[:])
|
||||
}
|
||||
47
relay/util/cache_db.go
Normal file
47
relay/util/cache_db.go
Normal file
@@ -0,0 +1,47 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatCacheDB struct{}
|
||||
|
||||
func (db *ChatCacheDB) Get(hash string, userId int) *ChatCacheProps {
|
||||
cache, _ := model.GetChatCache(hash, userId)
|
||||
if cache == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
props, err := common.UnmarshalString[ChatCacheProps](cache.Data)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &props
|
||||
}
|
||||
|
||||
func (db *ChatCacheDB) Set(hash string, props *ChatCacheProps, expire int64) error {
|
||||
return SetCacheDB(hash, props, expire)
|
||||
}
|
||||
|
||||
func SetCacheDB(hash string, props *ChatCacheProps, expire int64) error {
|
||||
data := common.Marshal(props)
|
||||
if data == "" {
|
||||
return errors.New("marshal error")
|
||||
}
|
||||
|
||||
expire = expire * 60
|
||||
expire += time.Now().Unix()
|
||||
|
||||
cache := &model.ChatCache{
|
||||
Hash: hash,
|
||||
UserId: props.UserId,
|
||||
Data: data,
|
||||
Expiration: expire,
|
||||
}
|
||||
|
||||
return cache.Insert()
|
||||
}
|
||||
44
relay/util/cache_redis.go
Normal file
44
relay/util/cache_redis.go
Normal file
@@ -0,0 +1,44 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"one-api/common"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ChatCacheRedis struct{}
|
||||
|
||||
var chatCacheKey = "chat_cache"
|
||||
|
||||
func (r *ChatCacheRedis) Get(hash string, userId int) *ChatCacheProps {
|
||||
cache, err := common.RedisGet(r.getKey(hash, userId))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
props, err := common.UnmarshalString[ChatCacheProps](cache)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &props
|
||||
}
|
||||
|
||||
func (r *ChatCacheRedis) Set(hash string, props *ChatCacheProps, expire int64) error {
|
||||
|
||||
if !props.Cache {
|
||||
return nil
|
||||
}
|
||||
|
||||
data := common.Marshal(&props)
|
||||
if data == "" {
|
||||
return errors.New("marshal error")
|
||||
}
|
||||
|
||||
return common.RedisSet(r.getKey(hash, props.UserId), data, time.Duration(expire)*time.Minute)
|
||||
}
|
||||
|
||||
func (r *ChatCacheRedis) getKey(hash string, userId int) string {
|
||||
return fmt.Sprintf("%s:%d:%s", chatCacheKey, userId, hash)
|
||||
}
|
||||
Reference in New Issue
Block a user