支持按次收费的 OpenAI 实时语音通话功能

This commit is contained in:
RockYang
2024-12-20 18:21:54 +08:00
24 changed files with 1956 additions and 821 deletions

View File

@@ -149,6 +149,7 @@ type SystemConfig struct {
DallPower int `json:"dall_power,omitempty"` // DALL-E-3 绘图消耗算力
SunoPower int `json:"suno_power,omitempty"` // Suno 生成歌曲消耗算力
LumaPower int `json:"luma_power,omitempty"` // Luma 生成视频消耗算力
AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // 高级语音对话消耗算力
WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
@@ -166,4 +167,5 @@ type SystemConfig struct {
EnabledVerify bool `json:"enabled_verify"` // 是否启用验证码
EmailWhiteList []string `json:"email_white_list"` // 邮箱白名单列表
TranslateModelId int `json:"translate_model_id"` // 用来做提示词翻译的大模型 id
}

View File

@@ -8,6 +8,8 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/csv"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
@@ -35,12 +37,10 @@ func (h *RedeemHandler) List(c *gin.Context) {
session := h.DB.Session(&gorm.Session{})
if code != "" {
session.Where("code LIKE ?", "%"+code+"%")
session = session.Where("code LIKE ?", "%"+code+"%")
}
if status == 0 {
session.Where("redeem_at = ?", 0)
} else if status == 1 {
session.Where("redeem_at > ?", 0)
if status >= 0 {
session = session.Where("redeemed_at", status)
}
var total int64
@@ -80,6 +80,65 @@ func (h *RedeemHandler) List(c *gin.Context) {
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
// Export 导出 CVS 文件
func (h *RedeemHandler) Export(c *gin.Context) {
var data struct {
Status int `json:"status"`
Ids []int `json:"ids"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
}
session := h.DB.Session(&gorm.Session{})
if data.Status >= 0 {
session = session.Where("redeemed_at", data.Status)
}
if len(data.Ids) > 0 {
session = session.Where("id IN ?", data.Ids)
}
var items []model.Redeem
err := session.Order("id DESC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 设置响应头,告诉浏览器这是一个附件,需要下载
c.Header("Content-Disposition", "attachment; filename=output.csv")
c.Header("Content-Type", "text/csv")
// 创建一个 CSV writer
writer := csv.NewWriter(c.Writer)
// 写入 CSV 文件的标题行
headers := []string{"名称", "兑换码", "算力", "创建时间"}
if err := writer.Write(headers); err != nil {
resp.ERROR(c, err.Error())
return
}
// 写入数据行
records := make([][]string, 0)
for _, item := range items {
records = append(records, []string{item.Name, item.Code, fmt.Sprintf("%d", item.Power), item.CreatedAt.Format("2006-01-02 15:04:05")})
}
for _, record := range records {
if err := writer.Write(record); err != nil {
resp.ERROR(c, err.Error())
return
}
}
// 确保所有数据都已写入响应
writer.Flush()
if err := writer.Error(); err != nil {
resp.ERROR(c, err.Error())
return
}
}
func (h *RedeemHandler) Create(c *gin.Context) {
var data struct {
Name string `json:"name"`

View File

@@ -1,15 +1,24 @@
package handler
import (
"encoding/json"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"geekai/utils"
"geekai/utils/resp"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
@@ -23,10 +32,11 @@ import (
type RealtimeHandler struct {
BaseHandler
userService *service.UserService
}
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler {
return &RealtimeHandler{BaseHandler{App: server, DB: db}}
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *service.UserService) *RealtimeHandler {
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
}
func (h *RealtimeHandler) Connection(c *gin.Context) {
@@ -126,3 +136,73 @@ func sendError(ws *websocket.Conn, message string) {
logger.Error(err)
}
}
// OpenAI 实时语音对话,一次性对话
func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
var apiKey model.ApiKey
err := h.DB.Session(&gorm.Session{}).Where("type", "realtime").Where("enabled", true).First(&apiKey).Error
if err != nil {
resp.ERROR(c, fmt.Sprintf("error with fetch OpenAI API KEY%v", err))
}
var response utils.OpenAIResponse
client := req.C()
if len(apiKey.ProxyURL) > 5 {
client.SetProxyURL(apiKey.ApiURL)
}
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
logger.Infof("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, "advanced-voice")
r, err := client.R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(types.ApiRequest{
Model: "advanced-voice",
Temperature: 0.9,
MaxTokens: 1024,
Stream: false,
Messages: []interface{}{types.Message{
Role: "user",
Content: "实时语音通话",
}},
}).Post(apiURL)
if err != nil {
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败%v", err))
return
}
if r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败%v", r.Status))
return
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &response)
if err != nil {
resp.ERROR(c, fmt.Sprintf("解析API数据失败%v, %s", err, string(body)))
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 扣减算力
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(int(userId), h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
Type: types.PowerConsume,
Model: "advanced-voice",
Remark: "实时语音通话",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
logger.Infof("Response: %v", response.Choices[0].Message.Content)
// 提取链接
re := regexp.MustCompile(`\[(.*?)\]\((.*?)\)`)
links := re.FindAllStringSubmatch(response.Choices[0].Message.Content, -1)
var url = ""
if len(links) > 0 {
url = links[0][2]
}
resp.SUCCESS(c, url)
}

View File

@@ -215,8 +215,8 @@ func (h *SunoHandler) Remove(c *gin.Context) {
return
}
// 只有失败或者超时的任务才能删除
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*10)) {
// 只有失败或者已完成的任务可以删除
if !(job.Progress == service.FailTaskProgress || job.Progress == 100) {
resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
return
}

View File

@@ -350,6 +350,7 @@ func main() {
group.POST("create", h.Create)
group.POST("set", h.Set)
group.GET("remove", h.Remove)
group.POST("export", h.Export)
}),
fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
group := s.Engine.Group("/api/admin/dashboard/")
@@ -564,6 +565,7 @@ func main() {
fx.Provide(handler.NewRealtimeHandler),
fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
s.Engine.Any("/api/realtime", h.Connection)
s.Engine.POST("/api/realtime/voice", h.VoiceChat)
}),
)
// 启动应用程序

View File

@@ -130,10 +130,13 @@ type LumaRespVo struct {
Id string `json:"id"`
Prompt string `json:"prompt"`
State string `json:"state"`
CreatedAt time.Time `json:"created_at"`
QueueState interface{} `json:"queue_state"`
CreatedAt string `json:"created_at"`
Video interface{} `json:"video"`
VideoRaw interface{} `json:"video_raw"`
Liked interface{} `json:"liked"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
Thumbnail interface{} `json:"thumbnail"`
Channel string `json:"channel,omitempty"`
}
@@ -234,7 +237,7 @@ func (s *Service) DownloadFiles() {
continue
}
}
logger.Info("download no water video success: %s", videoURL)
logger.Infof("download no water video success: %s", videoURL)
v.VideoURL = videoURL
v.Progress = 100
s.db.Updates(&v)
@@ -275,6 +278,7 @@ func (s *Service) SyncTaskProgress() {
"water_url": task.Video.Url,
"raw_data": utils.JsonEncode(task),
"prompt_ext": task.Prompt,
"cover_url": task.Thumbnail.Url,
}
if task.Video.DownloadUrl != "" {
data["video_url"] = task.Video.DownloadUrl
@@ -315,11 +319,28 @@ type LumaTaskVo struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
Thumbnail string `json:"thumbnail"`
DownloadUrl string `json:"download_url"`
} `json:"video"`
Prompt string `json:"prompt"`
CreatedAt time.Time `json:"created_at"`
EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
Prompt string `json:"prompt"`
UserId string `json:"user_id"`
BatchId string `json:"batch_id"`
Thumbnail struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"thumbnail"`
VideoRaw struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"video_raw"`
CreatedAt string `json:"created_at"`
LastFrame struct {
Url string `json:"url"`
Width int `json:"width"`
Height int `json:"height"`
} `json:"last_frame"`
}
func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {

View File

@@ -12,11 +12,12 @@ import (
"fmt"
"geekai/core/types"
"geekai/store/model"
"io"
"time"
"github.com/imroc/req/v3"
"github.com/pkoukk/tiktoken-go"
"gorm.io/gorm"
"io"
"time"
)
func CalcTokens(text string, model string) (int, error) {
@@ -33,7 +34,7 @@ func CalcTokens(text string, model string) (int, error) {
return len(token), nil
}
type apiRes struct {
type OpenAIResponse struct {
Model string `json:"model"`
Choices []struct {
Index int `json:"index"`
@@ -70,7 +71,7 @@ func SendOpenAIMessage(db *gorm.DB, messages []interface{}, modelId int) (string
return "", fmt.Errorf("error with fetch OpenAI API KEY%v", err)
}
var response apiRes
var response OpenAIResponse
client := req.C()
if len(apiKey.ProxyURL) > 5 {
client.SetProxyURL(apiKey.ApiURL)