mirror of https://github.com/yangjian102621/geekai.git

Support pay-per-use billing for the OpenAI realtime voice call feature
@@ -149,6 +149,7 @@ type SystemConfig struct {
    DallPower int `json:"dall_power,omitempty"` // power cost of a DALL-E-3 image generation
    SunoPower int `json:"suno_power,omitempty"` // power cost of a Suno song generation
    LumaPower int `json:"luma_power,omitempty"` // power cost of a Luma video generation
    AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // power cost of an advanced voice chat

    WechatCardURL string `json:"wechat_card_url,omitempty"` // WeChat customer-service contact URL

@@ -166,4 +167,5 @@ type SystemConfig struct {
    EnabledVerify bool `json:"enabled_verify"` // whether captcha verification is enabled
    EmailWhiteList []string `json:"email_white_list"` // email whitelist
    TranslateModelId int `json:"translate_model_id"` // id of the LLM used for prompt translation

}
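For reference, a minimal sketch of how the new advance_voice_power key maps onto the struct when the system-config JSON is decoded. The field names come from the hunk above; the sample JSON values are invented.

package main

import (
    "encoding/json"
    "fmt"
)

// Only the fields relevant to this commit; the real SystemConfig has many more.
type SystemConfig struct {
    LumaPower         int `json:"luma_power,omitempty"`
    AdvanceVoicePower int `json:"advance_voice_power,omitempty"` // per-call cost of advanced voice chat
}

func main() {
    raw := []byte(`{"luma_power": 20, "advance_voice_power": 15}`) // sample values, not defaults
    var cfg SystemConfig
    if err := json.Unmarshal(raw, &cfg); err != nil {
        panic(err)
    }
    fmt.Println(cfg.AdvanceVoicePower) // 15
}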
@@ -8,6 +8,8 @@ package admin
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

import (
    "encoding/csv"
    "fmt"
    "geekai/core"
    "geekai/core/types"
    "geekai/handler"
@@ -35,12 +37,10 @@ func (h *RedeemHandler) List(c *gin.Context) {

    session := h.DB.Session(&gorm.Session{})
    if code != "" {
        session.Where("code LIKE ?", "%"+code+"%")
        session = session.Where("code LIKE ?", "%"+code+"%")
    }
    if status == 0 {
        session.Where("redeem_at = ?", 0)
    } else if status == 1 {
        session.Where("redeem_at > ?", 0)
    if status >= 0 {
        session = session.Where("redeemed_at", status)
    }

    var total int64
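The List fix works because gorm's query methods do not mutate a fresh session: Where returns a new *gorm.DB carrying the condition, so the result must be reassigned (session = session.Where(...)), otherwise the filter is silently lost. A small sketch of the corrected pattern; the filterRedeems helper and its arguments are illustrative, not part of the repo.

package admin

import "gorm.io/gorm"

// filterRedeems builds the redeem-code query; illustrative helper only.
func filterRedeems(db *gorm.DB, code string, status int) *gorm.DB {
    session := db.Session(&gorm.Session{})
    if code != "" {
        // Wrong: the returned *gorm.DB is discarded, so the condition never applies.
        // session.Where("code LIKE ?", "%"+code+"%")

        // Right: keep the returned query builder.
        session = session.Where("code LIKE ?", "%"+code+"%")
    }
    if status >= 0 {
        // Column/value shorthand, as used elsewhere in this commit.
        session = session.Where("redeemed_at", status)
    }
    return session
}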
@@ -80,6 +80,65 @@ func (h *RedeemHandler) List(c *gin.Context) {
    resp.SUCCESS(c, vo.NewPage(total, page, pageSize, items))
}

// Export writes the selected redeem codes to a CSV file
func (h *RedeemHandler) Export(c *gin.Context) {
    var data struct {
        Status int   `json:"status"`
        Ids    []int `json:"ids"`
    }
    if err := c.ShouldBindJSON(&data); err != nil {
        resp.ERROR(c, types.InvalidArgs)
        return
    }

    session := h.DB.Session(&gorm.Session{})
    if data.Status >= 0 {
        session = session.Where("redeemed_at", data.Status)
    }
    if len(data.Ids) > 0 {
        session = session.Where("id IN ?", data.Ids)
    }

    var items []model.Redeem
    err := session.Order("id DESC").Find(&items).Error
    if err != nil {
        resp.ERROR(c, err.Error())
        return
    }

    // Set response headers so the browser treats the reply as a downloadable attachment
    c.Header("Content-Disposition", "attachment; filename=output.csv")
    c.Header("Content-Type", "text/csv")

    // Create a CSV writer
    writer := csv.NewWriter(c.Writer)

    // Write the CSV header row
    headers := []string{"名称", "兑换码", "算力", "创建时间"}
    if err := writer.Write(headers); err != nil {
        resp.ERROR(c, err.Error())
        return
    }

    // Write the data rows
    records := make([][]string, 0)
    for _, item := range items {
        records = append(records, []string{item.Name, item.Code, fmt.Sprintf("%d", item.Power), item.CreatedAt.Format("2006-01-02 15:04:05")})
    }
    for _, record := range records {
        if err := writer.Write(record); err != nil {
            resp.ERROR(c, err.Error())
            return
        }
    }

    // Make sure all buffered data has been written to the response
    writer.Flush()
    if err := writer.Error(); err != nil {
        resp.ERROR(c, err.Error())
        return
    }
}

func (h *RedeemHandler) Create(c *gin.Context) {
    var data struct {
        Name string `json:"name"`
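A rough client-side sketch of calling the new export route (registered as POST /api/admin/redeem/export in the main.go hunk further down). The JSON body mirrors the handler's bind struct, and a negative status skips the redeemed_at filter per the handler above; the base URL and the admin auth header name are assumptions.

package main

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
    "os"
)

func main() {
    // Body fields match the handler's bind struct: status and ids.
    payload := bytes.NewBufferString(`{"status": -1, "ids": []}`)

    req, _ := http.NewRequest(http.MethodPost, "http://localhost:5678/api/admin/redeem/export", payload)
    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("Admin-Authorization", "<admin token>") // header name is an assumption

    res, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer res.Body.Close()

    out, _ := os.Create("redeem.csv")
    defer out.Close()
    n, _ := io.Copy(out, res.Body) // save the streamed CSV
    fmt.Println("wrote", n, "bytes")
}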
@@ -1,15 +1,24 @@
package handler

import (
    "encoding/json"
    "fmt"
    "geekai/core"
    "geekai/core/types"
    "geekai/service"
    "geekai/store/model"
    "github.com/gin-gonic/gin"
    "github.com/gorilla/websocket"
    "gorm.io/gorm"
    "geekai/utils"
    "geekai/utils/resp"
    "io"
    "net/http"
    "regexp"
    "strings"
    "time"

    "github.com/gin-gonic/gin"
    "github.com/gorilla/websocket"
    "github.com/imroc/req/v3"
    "gorm.io/gorm"
)

// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
@@ -23,10 +32,11 @@ import (

type RealtimeHandler struct {
    BaseHandler
    userService *service.UserService
}

func NewRealtimeHandler(server *core.AppServer, db *gorm.DB) *RealtimeHandler {
    return &RealtimeHandler{BaseHandler{App: server, DB: db}}
func NewRealtimeHandler(server *core.AppServer, db *gorm.DB, userService *service.UserService) *RealtimeHandler {
    return &RealtimeHandler{BaseHandler: BaseHandler{App: server, DB: db}, userService: userService}
}

func (h *RealtimeHandler) Connection(c *gin.Context) {
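Because NewRealtimeHandler now also takes a *service.UserService, the fx graph in main.go has to provide one; the existing fx.Provide(handler.NewRealtimeHandler) line itself needs no change. A minimal wiring sketch, under the assumption that a service.NewUserService constructor is registered; the other providers are elided.

package main

import (
    "geekai/core"
    "geekai/handler"
    "geekai/service"

    "go.uber.org/fx"
)

func main() {
    fx.New(
        // ...existing providers for *core.AppServer, *gorm.DB, etc. ...
        fx.Provide(service.NewUserService),     // assumed constructor; must yield *service.UserService
        fx.Provide(handler.NewRealtimeHandler), // fx now injects server, db and userService
        fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
            s.Engine.Any("/api/realtime", h.Connection)
            s.Engine.POST("/api/realtime/voice", h.VoiceChat)
        }),
    ).Run()
}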
@@ -126,3 +136,73 @@ func sendError(ws *websocket.Conn, message string) {
        logger.Error(err)
    }
}

// VoiceChat runs a one-off OpenAI realtime voice conversation, billed per call
func (h *RealtimeHandler) VoiceChat(c *gin.Context) {
    var apiKey model.ApiKey
    err := h.DB.Session(&gorm.Session{}).Where("type", "realtime").Where("enabled", true).First(&apiKey).Error
    if err != nil {
        resp.ERROR(c, fmt.Sprintf("error with fetch OpenAI API KEY:%v", err))
        return
    }

    var response utils.OpenAIResponse
    client := req.C()
    if len(apiKey.ProxyURL) > 5 {
        client.SetProxyURL(apiKey.ProxyURL)
    }
    apiURL := fmt.Sprintf("%s/v1/chat/completions", apiKey.ApiURL)
    logger.Infof("Sending request, base URL: %s, API URL: %s, PROXY: %s, Model: %s", apiKey.ApiURL, apiURL, apiKey.ProxyURL, "advanced-voice")
    r, err := client.R().SetHeader("Content-Type", "application/json").
        SetHeader("Authorization", "Bearer "+apiKey.Value).
        SetBody(types.ApiRequest{
            Model:       "advanced-voice",
            Temperature: 0.9,
            MaxTokens:   1024,
            Stream:      false,
            Messages: []interface{}{types.Message{
                Role:    "user",
                Content: "实时语音通话",
            }},
        }).Post(apiURL)
    if err != nil {
        resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败:%v", err))
        return
    }

    if r.IsErrorState() {
        resp.ERROR(c, fmt.Sprintf("请求 OpenAI API失败:%v", r.Status))
        return
    }

    body, _ := io.ReadAll(r.Body)
    err = json.Unmarshal(body, &response)
    if err != nil {
        resp.ERROR(c, fmt.Sprintf("解析API数据失败:%v, %s", err, string(body)))
        return
    }

    // Update the API key's last-used timestamp
    h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())

    // Deduct the user's power for this call
    userId := h.GetLoginUserId(c)
    err = h.userService.DecreasePower(int(userId), h.App.SysConfig.AdvanceVoicePower, model.PowerLog{
        Type:   types.PowerConsume,
        Model:  "advanced-voice",
        Remark: "实时语音通话",
    })
    if err != nil {
        resp.ERROR(c, err.Error())
        return
    }

    logger.Infof("Response: %v", response.Choices[0].Message.Content)

    // Extract the first markdown-style link from the response
    re := regexp.MustCompile(`\[(.*?)\]\((.*?)\)`)
    links := re.FindAllStringSubmatch(response.Choices[0].Message.Content, -1)
    var url = ""
    if len(links) > 0 {
        url = links[0][2]
    }
    resp.SUCCESS(c, url)
}
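The handler assumes the relay model answers with a markdown-style link whose target is the live call URL, and it returns only that URL. A self-contained illustration of the extraction step; the sample reply text is invented.

package main

import (
    "fmt"
    "regexp"
)

func main() {
    // Same pattern as VoiceChat: capture group 1 = link text, group 2 = URL.
    re := regexp.MustCompile(`\[(.*?)\]\((.*?)\)`)

    reply := "点击进入通话 [加入会话](https://example.com/voice/abc123)" // invented sample reply
    links := re.FindAllStringSubmatch(reply, -1)

    url := ""
    if len(links) > 0 {
        url = links[0][2] // first match, second capture group
    }
    fmt.Println(url) // https://example.com/voice/abc123
}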
@@ -215,8 +215,8 @@ func (h *SunoHandler) Remove(c *gin.Context) {
        return
    }

    // Only failed or timed-out tasks can be deleted
    if job.Progress != service.FailTaskProgress || time.Now().Before(job.CreatedAt.Add(time.Minute*10)) {
    // Only failed or completed tasks can be deleted
    if !(job.Progress == service.FailTaskProgress || job.Progress == 100) {
        resp.ERROR(c, "只有失败和超时(10分钟)的任务才能删除!")
        return
    }
@@ -350,6 +350,7 @@ func main() {
            group.POST("create", h.Create)
            group.POST("set", h.Set)
            group.GET("remove", h.Remove)
            group.POST("export", h.Export)
        }),
        fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
            group := s.Engine.Group("/api/admin/dashboard/")
@@ -564,6 +565,7 @@ func main() {
        fx.Provide(handler.NewRealtimeHandler),
        fx.Invoke(func(s *core.AppServer, h *handler.RealtimeHandler) {
            s.Engine.Any("/api/realtime", h.Connection)
            s.Engine.POST("/api/realtime/voice", h.VoiceChat)
        }),
    )
    // Start the application
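Once the POST /api/realtime/voice route is registered, a logged-in client starts a billed call with a single POST; the handler deducts AdvanceVoicePower and responds with the extracted call URL. A rough sketch; the base URL, the auth header name, and the {code, data} success envelope are assumptions.

package main

import (
    "encoding/json"
    "fmt"
    "io"
    "net/http"
)

func main() {
    req, _ := http.NewRequest(http.MethodPost, "http://localhost:5678/api/realtime/voice", nil)
    req.Header.Set("Authorization", "<user token>") // header name is an assumption

    res, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer res.Body.Close()

    raw, _ := io.ReadAll(res.Body)

    // Assumed envelope produced by resp.SUCCESS: {"code": ..., "data": "<call url>"}.
    var reply struct {
        Code int    `json:"code"`
        Data string `json:"data"`
    }
    if err := json.Unmarshal(raw, &reply); err != nil {
        panic(err)
    }
    fmt.Println("voice call url:", reply.Data)
}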
@@ -130,10 +130,13 @@ type LumaRespVo struct {
    Id string `json:"id"`
    Prompt string `json:"prompt"`
    State string `json:"state"`
    CreatedAt time.Time `json:"created_at"`
    QueueState interface{} `json:"queue_state"`
    CreatedAt string `json:"created_at"`
    Video interface{} `json:"video"`
    VideoRaw interface{} `json:"video_raw"`
    Liked interface{} `json:"liked"`
    EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
    Thumbnail interface{} `json:"thumbnail"`
    Channel string `json:"channel,omitempty"`
}

@@ -234,7 +237,7 @@ func (s *Service) DownloadFiles() {
                continue
            }
        }
        logger.Info("download no water video success: %s", videoURL)
        logger.Infof("download no water video success: %s", videoURL)
        v.VideoURL = videoURL
        v.Progress = 100
        s.db.Updates(&v)
@@ -275,6 +278,7 @@ func (s *Service) SyncTaskProgress() {
            "water_url": task.Video.Url,
            "raw_data": utils.JsonEncode(task),
            "prompt_ext": task.Prompt,
            "cover_url": task.Thumbnail.Url,
        }
        if task.Video.DownloadUrl != "" {
            data["video_url"] = task.Video.DownloadUrl
@@ -315,11 +319,28 @@ type LumaTaskVo struct {
        Url string `json:"url"`
        Width int `json:"width"`
        Height int `json:"height"`
        Thumbnail string `json:"thumbnail"`
        DownloadUrl string `json:"download_url"`
    } `json:"video"`
    Prompt string `json:"prompt"`
    CreatedAt time.Time `json:"created_at"`
    EstimateWaitSeconds interface{} `json:"estimate_wait_seconds"`
    Prompt string `json:"prompt"`
    UserId string `json:"user_id"`
    BatchId string `json:"batch_id"`
    Thumbnail struct {
        Url string `json:"url"`
        Width int `json:"width"`
        Height int `json:"height"`
    } `json:"thumbnail"`
    VideoRaw struct {
        Url string `json:"url"`
        Width int `json:"width"`
        Height int `json:"height"`
    } `json:"video_raw"`
    CreatedAt string `json:"created_at"`
    LastFrame struct {
        Url string `json:"url"`
        Width int `json:"width"`
        Height int `json:"height"`
    } `json:"last_frame"`
}

func (s *Service) QueryLumaTask(taskId string, channel string) (LumaTaskVo, error) {
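For context, a minimal decode of a Luma task payload into the reshaped LumaTaskVo; the SyncTaskProgress hunk above reads task.Video.DownloadUrl and task.Thumbnail.Url from it. Only fields visible in the hunk are kept, and the sample JSON values are invented.

package main

import (
    "encoding/json"
    "fmt"
)

// Trimmed copy of the reshaped struct; the real LumaTaskVo has more fields.
type LumaTaskVo struct {
    Video struct {
        Url         string `json:"url"`
        DownloadUrl string `json:"download_url"`
    } `json:"video"`
    Prompt    string `json:"prompt"`
    CreatedAt string `json:"created_at"` // now a plain string instead of time.Time
    Thumbnail struct {
        Url string `json:"url"`
    } `json:"thumbnail"`
}

func main() {
    raw := []byte(`{
        "prompt": "a red fox in the snow",
        "created_at": "2024-10-01T08:00:00Z",
        "video": {"url": "https://example.com/w.mp4", "download_url": "https://example.com/raw.mp4"},
        "thumbnail": {"url": "https://example.com/t.jpg"}
    }`) // invented sample payload

    var task LumaTaskVo
    if err := json.Unmarshal(raw, &task); err != nil {
        panic(err)
    }
    fmt.Println(task.Video.DownloadUrl, task.CreatedAt)
}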
@@ -12,11 +12,12 @@ import (
    "fmt"
    "geekai/core/types"
    "geekai/store/model"
    "io"
    "time"

    "github.com/imroc/req/v3"
    "github.com/pkoukk/tiktoken-go"
    "gorm.io/gorm"
    "io"
    "time"
)

func CalcTokens(text string, model string) (int, error) {
@@ -33,7 +34,7 @@ func CalcTokens(text string, model string) (int, error) {
    return len(token), nil
}

type apiRes struct {
type OpenAIResponse struct {
    Model string `json:"model"`
    Choices []struct {
        Index int `json:"index"`
@@ -70,7 +71,7 @@ func SendOpenAIMessage(db *gorm.DB, messages []interface{}, modelId int) (string
        return "", fmt.Errorf("error with fetch OpenAI API KEY:%v", err)
    }

    var response apiRes
    var response OpenAIResponse
    client := req.C()
    if len(apiKey.ProxyURL) > 5 {
        client.SetProxyURL(apiKey.ProxyURL)