支持按次收费的 OpenAI 实时语音通话功能

This commit is contained in:
RockYang
2024-12-20 18:21:54 +08:00
24 changed files with 1956 additions and 821 deletions

View File

@@ -8,6 +8,8 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"encoding/csv"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/handler"
@@ -35,12 +37,10 @@ func (h *RedeemHandler) List(c *gin.Context) {
session := h.DB.Session(&gorm.Session{})
if code != "" {
session.Where("code LIKE ?", "%"+code+"%")
session = session.Where("code LIKE ?", "%"+code+"%")
}
if status == 0 {
session.Where("redeem_at = ?", 0)
} else if status == 1 {
session.Where("redeem_at > ?", 0)
if status >= 0 {
session = session.Where("redeemed_at", status)
}
var total int64
@@ -80,6 +80,65 @@ func (h *RedeemHandler) List(c *gin.Context) {
resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}
// Export 导出 CVS 文件
func (h *RedeemHandler) Export(c *gin.Context) {
var data struct {
Status int `json:"status"`
Ids []int `json:"ids"`
}
if err := c.ShouldBindJSON(&data); err != nil {
resp.ERROR(c, types.InvalidArgs)
}
session := h.DB.Session(&gorm.Session{})
if data.Status >= 0 {
session = session.Where("redeemed_at", data.Status)
}
if len(data.Ids) > 0 {
session = session.Where("id IN ?", data.Ids)
}
var items []model.Redeem
err := session.Order("id DESC").Find(&items).Error
if err != nil {
resp.ERROR(c, err.Error())
return
}
// 设置响应头,告诉浏览器这是一个附件,需要下载
c.Header("Content-Disposition", "attachment; filename=output.csv")
c.Header("Content-Type", "text/csv")
// 创建一个 CSV writer
writer := csv.NewWriter(c.Writer)
// 写入 CSV 文件的标题行
headers := []string{"名称", "兑换码", "算力", "创建时间"}
if err := writer.Write(headers); err != nil {
resp.ERROR(c, err.Error())
return
}
// 写入数据行
records := make([][]string, 0)
for _, item := range items {
records = append(records, []string{item.Name, item.Code, fmt.Sprintf("%d", item.Power), item.CreatedAt.Format("2006-01-02 15:04:05")})
}
for _, record := range records {
if err := writer.Write(record); err != nil {
resp.ERROR(c, err.Error())
return
}
}
// 确保所有数据都已写入响应
writer.Flush()
if err := writer.Error(); err != nil {
resp.ERROR(c, err.Error())
return
}
}
func (h *RedeemHandler) Create(c *gin.Context) {
var data struct {
Name string `json:"name"`

View File

@@ -1,15 +1,24 @@
package handler
import (
"encoding/json"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/service"
"geekai/store/model"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"gorm.io/gorm"
"geekai/utils"
"geekai/utils/resp"
"io"
"net/http"
"regexp"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"github.com/imroc/req/v3"
"gorm.io/gorm"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
@@ -23,10 +32,11 @@ import (
type RealtimeHandler struct {
BaseHandler
userService *service.UserService
}
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler {
return &RealtimeHandler{BaseHandler{App: server, DB: db}}
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *service.UserService) *RealtimeHandler {
return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
}
func (h *RealtimeHandler) Connection(c *gin.Context) {
@@ -126,3 +136,73 @@ func sendError(ws *websocket.Conn, message string) {
logger.Error(err)
}
}
// OpenAI 实时语音对话,一次性对话
func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
var apiKey model.ApiKey
err := h.DB.Session(&gorm.Session{}).Where("type", "realtime").Where("enabled", true).First(&apiKey).Error
if err != nil {
resp.ERROR(c, fmt.Sprintf("error with fetch OpenAI API KEY%v", err))
}
var response utils.OpenAIResponse
client := req.C()
if len(apiKey.ProxyURL) > 5 {
client.SetProxyURL(apiKey.ApiURL)
}
apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
logger.Infof("Sending %s request, API KEY:%s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, "advanced-voice")
r, err := client.R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
SetBody(types.ApiRequest{
Model: "advanced-voice",
Temperature: 0.9,
MaxTokens: 1024,
Stream: false,
Messages: []interface{}{types.Message{
Role: "user",
Content: "实时语音通话",
}},
}).Post(apiURL)
if err != nil {
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败%v", err))
return
}
if r.IsErrorState() {
resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败%v", r.Status))
return
}
body, _ := io.ReadAll(r.Body)
err = json.Unmarshal(body, &response)
if err != nil {
resp.ERROR(c, fmt.Sprintf("解析API数据失败%v, %s", err, string(body)))
}
// 更新 API KEY 的最后使用时间
h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
// 扣减算力
userId := h.GetLoginUserId(c)
err = h.userService.DecreasePower(int(userId), h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
Type: types.PowerConsume,
Model: "advanced-voice",
Remark: "实时语音通话",
})
if err != nil {
resp.ERROR(c, err.Error())
return
}
logger.Infof("Response: %v", response.Choices[0].Message.Content)
// 提取链接
re := regexp.MustCompile(`\[(.*?)\]\((.*?)\)`)
links := re.FindAllStringSubmatch(response.Choices[0].Message.Content, -1)
var url = ""
if len(links) > 0 {
url = links[0][2]
}
resp.SUCCESS(c, url)
}

View File

@@ -215,8 +215,8 @@ func (h *SunoHandler) Remove(c *gin.Context) {
return
}
// 只有失败或者超时的任务才能删除
if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*10)) {
// 只有失败或者已完成的任务可以删除
if !(job.Progress == service.FailTaskProgress || job.Progress == 100) {
resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
return
}