mirror of
https://github.com/linux-do/new-api.git
synced 2025-11-09 15:43:41 +08:00
add epay
This commit is contained in:
@@ -94,6 +94,23 @@ func SearchUserLogs(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
func GetLogByKey(c *gin.Context) {
|
||||
key := c.Query("key")
|
||||
logs, err := model.GetLogByKey(key)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
}
|
||||
|
||||
func GetLogsStat(c *gin.Context) {
|
||||
logType, _ := strconv.Atoi(c.Query("type"))
|
||||
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)
|
||||
|
||||
133
controller/midjourney.go
Normal file
133
controller/midjourney.go
Normal file
@@ -0,0 +1,133 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
func UpdateMidjourneyTask() {
|
||||
imageModel := "midjourney"
|
||||
for {
|
||||
time.Sleep(time.Duration(15) * time.Second)
|
||||
tasks := model.GetAllUnFinishTasks()
|
||||
if len(tasks) != 0 {
|
||||
//log.Printf("UpdateMidjourneyTask: %v", time.Now())
|
||||
ids := make([]string, 0)
|
||||
for _, task := range tasks {
|
||||
ids = append(ids, task.MjId)
|
||||
}
|
||||
requestUrl := "http://107.173.171.147:8080/mj/task/list-by-condition"
|
||||
requestBody := map[string]interface{}{
|
||||
"ids": ids,
|
||||
}
|
||||
jsonStr, err := json.Marshal(requestBody)
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask: %v", err)
|
||||
}
|
||||
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr))
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask: %v", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu")
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
var response []Midjourney
|
||||
err = json.NewDecoder(resp.Body).Decode(&response)
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTask: %v", err)
|
||||
}
|
||||
for _, responseItem := range response {
|
||||
var midjourneyTask *model.Midjourney
|
||||
for _, mj := range tasks {
|
||||
mj.MjId = responseItem.MjId
|
||||
midjourneyTask = model.GetMjByuId(mj.Id)
|
||||
}
|
||||
if midjourneyTask != nil {
|
||||
midjourneyTask.Code = 1
|
||||
midjourneyTask.Progress = responseItem.Progress
|
||||
midjourneyTask.PromptEn = responseItem.PromptEn
|
||||
midjourneyTask.State = responseItem.State
|
||||
midjourneyTask.SubmitTime = responseItem.SubmitTime
|
||||
midjourneyTask.StartTime = responseItem.StartTime
|
||||
midjourneyTask.FinishTime = responseItem.FinishTime
|
||||
midjourneyTask.ImageUrl = responseItem.ImageUrl
|
||||
midjourneyTask.Status = responseItem.Status
|
||||
midjourneyTask.FailReason = responseItem.FailReason
|
||||
if midjourneyTask.Progress != "100%" && responseItem.FailReason != "" {
|
||||
log.Println(midjourneyTask.MjId + " 构建失败," + midjourneyTask.FailReason)
|
||||
midjourneyTask.Progress = "100%"
|
||||
err = model.CacheUpdateUserQuota(midjourneyTask.UserId)
|
||||
if err != nil {
|
||||
log.Println("error update user quota cache: " + err.Error())
|
||||
} else {
|
||||
modelRatio := common.GetModelRatio(imageModel)
|
||||
groupRatio := common.GetGroupRatio("default")
|
||||
ratio := modelRatio * groupRatio
|
||||
quota := int(ratio * 1 * 1000)
|
||||
if quota != 0 {
|
||||
err := model.IncreaseUserQuota(midjourneyTask.UserId, quota)
|
||||
if err != nil {
|
||||
log.Println("fail to increase user quota")
|
||||
}
|
||||
logContent := fmt.Sprintf("%s 构图失败,补偿 %s", midjourneyTask.MjId, common.LogQuota(quota))
|
||||
model.RecordLog(midjourneyTask.UserId, 1, logContent)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
err = midjourneyTask.Update()
|
||||
if err != nil {
|
||||
log.Printf("UpdateMidjourneyTaskFail: %v", err)
|
||||
}
|
||||
log.Printf("UpdateMidjourneyTask: %v", midjourneyTask)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func GetAllMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
}
|
||||
|
||||
func GetUserMidjourney(c *gin.Context) {
|
||||
p, _ := strconv.Atoi(c.Query("p"))
|
||||
if p < 0 {
|
||||
p = 0
|
||||
}
|
||||
userId := c.GetInt("id")
|
||||
log.Printf("userId = %d \n", userId)
|
||||
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage)
|
||||
if logs == nil {
|
||||
logs = make([]*model.Midjourney, 0)
|
||||
}
|
||||
c.JSON(200, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": logs,
|
||||
})
|
||||
}
|
||||
@@ -31,6 +31,9 @@ func GetStatus(c *gin.Context) {
|
||||
"chat_link": common.ChatLink,
|
||||
"quota_per_unit": common.QuotaPerUnit,
|
||||
"display_in_currency": common.DisplayInCurrencyEnabled,
|
||||
"normal_price": common.NormalPrice,
|
||||
"stable_price": common.StablePrice,
|
||||
"base_price": common.BasePrice,
|
||||
},
|
||||
})
|
||||
return
|
||||
@@ -58,6 +61,17 @@ func GetAbout(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
func GetMidjourney(c *gin.Context) {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
defer common.OptionMapRWMutex.RUnlock()
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": common.OptionMap["Midjourney"],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func GetHomePageContent(c *gin.Context) {
|
||||
common.OptionMapRWMutex.RLock()
|
||||
defer common.OptionMapRWMutex.RUnlock()
|
||||
|
||||
@@ -137,7 +137,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
|
||||
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
|
||||
388
controller/relay-mj.go
Normal file
388
controller/relay-mj.go
Normal file
@@ -0,0 +1,388 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type Midjourney struct {
|
||||
MjId string `json:"id"`
|
||||
Action string `json:"action"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"promptEn"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submitTime"`
|
||||
StartTime int64 `json:"startTime"`
|
||||
FinishTime int64 `json:"finishTime"`
|
||||
ImageUrl string `json:"imageUrl"`
|
||||
Status string `json:"status"`
|
||||
Progress string `json:"progress"`
|
||||
FailReason string `json:"failReason"`
|
||||
}
|
||||
|
||||
func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
|
||||
var midjRequest Midjourney
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "bind_request_body_failed",
|
||||
Properties: nil,
|
||||
Result: "",
|
||||
}
|
||||
}
|
||||
midjourneyTask := model.GetByMJId(midjRequest.MjId)
|
||||
if midjourneyTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "midjourney_task_not_found",
|
||||
Properties: nil,
|
||||
Result: "",
|
||||
}
|
||||
}
|
||||
midjourneyTask.Progress = midjRequest.Progress
|
||||
midjourneyTask.PromptEn = midjRequest.PromptEn
|
||||
midjourneyTask.State = midjRequest.State
|
||||
midjourneyTask.SubmitTime = midjRequest.SubmitTime
|
||||
midjourneyTask.StartTime = midjRequest.StartTime
|
||||
midjourneyTask.FinishTime = midjRequest.FinishTime
|
||||
midjourneyTask.ImageUrl = midjRequest.ImageUrl
|
||||
midjourneyTask.Status = midjRequest.Status
|
||||
midjourneyTask.FailReason = midjRequest.FailReason
|
||||
err = midjourneyTask.Update()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "update_midjourney_task_failed",
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
taskId := c.Param("id")
|
||||
originTask := model.GetByMJId(taskId)
|
||||
if originTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
}
|
||||
var midjourneyTask Midjourney
|
||||
midjourneyTask.MjId = originTask.MjId
|
||||
midjourneyTask.Progress = originTask.Progress
|
||||
midjourneyTask.PromptEn = originTask.PromptEn
|
||||
midjourneyTask.State = originTask.State
|
||||
midjourneyTask.SubmitTime = originTask.SubmitTime
|
||||
midjourneyTask.StartTime = originTask.StartTime
|
||||
midjourneyTask.FinishTime = originTask.FinishTime
|
||||
midjourneyTask.ImageUrl = originTask.ImageUrl
|
||||
midjourneyTask.Status = originTask.Status
|
||||
midjourneyTask.FailReason = originTask.FailReason
|
||||
midjourneyTask.Action = originTask.Action
|
||||
midjourneyTask.Description = originTask.Description
|
||||
midjourneyTask.Prompt = originTask.Prompt
|
||||
jsonMap, err := json.Marshal(midjourneyTask)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
_, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap))
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "copy_response_body_failed",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
|
||||
imageModel := "midjourney"
|
||||
|
||||
tokenId := c.GetInt("token_id")
|
||||
channelType := c.GetInt("channel")
|
||||
userId := c.GetInt("id")
|
||||
consumeQuota := c.GetBool("consume_quota")
|
||||
group := c.GetString("group")
|
||||
|
||||
var midjRequest MidjourneyRequest
|
||||
if consumeQuota {
|
||||
err := common.UnmarshalBodyReusable(c, &midjRequest)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "bind_request_body_failed",
|
||||
}
|
||||
}
|
||||
}
|
||||
if relayMode == RelayModeMidjourneyImagine {
|
||||
if midjRequest.Prompt == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "prompt_is_required",
|
||||
}
|
||||
}
|
||||
midjRequest.Action = "IMAGINE"
|
||||
} else if midjRequest.TaskId != "" {
|
||||
originTask := model.GetByMJId(midjRequest.TaskId)
|
||||
if originTask == nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_no_found",
|
||||
}
|
||||
} else if originTask.Action == "UPSCALE" {
|
||||
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "upscale_task_can_not_be_change",
|
||||
}
|
||||
} else if originTask.Status != "SUCCESS" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "task_status_is_not_success",
|
||||
}
|
||||
|
||||
}
|
||||
midjRequest.Prompt = originTask.Prompt
|
||||
} else if relayMode == RelayModeMidjourneyChange {
|
||||
if midjRequest.TaskId == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "taskId_is_required",
|
||||
}
|
||||
} else if midjRequest.Action == "" {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "action_is_required",
|
||||
}
|
||||
} else if midjRequest.Index == 0 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "index_can_only_be_1_2_3_4",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// map model name
|
||||
modelMapping := c.GetString("model_mapping")
|
||||
isModelMapped := false
|
||||
if modelMapping != "" {
|
||||
modelMap := make(map[string]string)
|
||||
err := json.Unmarshal([]byte(modelMapping), &modelMap)
|
||||
if err != nil {
|
||||
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_model_mapping_failed",
|
||||
}
|
||||
}
|
||||
if modelMap[imageModel] != "" {
|
||||
imageModel = modelMap[imageModel]
|
||||
isModelMapped = true
|
||||
}
|
||||
}
|
||||
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
|
||||
if c.GetString("base_url") != "" {
|
||||
baseURL = c.GetString("base_url")
|
||||
}
|
||||
|
||||
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
|
||||
|
||||
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
|
||||
|
||||
var requestBody io.Reader
|
||||
if isModelMapped {
|
||||
jsonStr, err := json.Marshal(midjRequest)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "marshal_text_request_failed",
|
||||
}
|
||||
}
|
||||
requestBody = bytes.NewBuffer(jsonStr)
|
||||
} else {
|
||||
requestBody = c.Request.Body
|
||||
}
|
||||
|
||||
modelRatio := common.GetModelRatio(imageModel)
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
|
||||
sizeRatio := 1.0
|
||||
|
||||
quota := int(ratio * sizeRatio * 1000)
|
||||
|
||||
if consumeQuota && userQuota-quota < 0 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "quota_not_enough",
|
||||
}
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "create_request_failed",
|
||||
}
|
||||
}
|
||||
//req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
|
||||
|
||||
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
|
||||
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
|
||||
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
|
||||
// print request header
|
||||
log.Printf("request header: %s", req.Header)
|
||||
log.Printf("request body: %s", midjRequest.Prompt)
|
||||
|
||||
resp, err := httpClient.Do(req)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "do_request_failed",
|
||||
}
|
||||
}
|
||||
|
||||
err = req.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_request_body_failed",
|
||||
}
|
||||
}
|
||||
err = c.Request.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_request_body_failed",
|
||||
}
|
||||
}
|
||||
var midjResponse MidjourneyResponse
|
||||
|
||||
defer func() {
|
||||
if consumeQuota {
|
||||
err := model.PostConsumeTokenQuota(tokenId, quota)
|
||||
if err != nil {
|
||||
common.SysError("error consuming token remain quota: " + err.Error())
|
||||
}
|
||||
err = model.CacheUpdateUserQuota(userId)
|
||||
if err != nil {
|
||||
common.SysError("error update user quota cache: " + err.Error())
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
//if consumeQuota {
|
||||
//
|
||||
//}
|
||||
responseBody, err := io.ReadAll(resp.Body)
|
||||
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "read_response_body_failed",
|
||||
}
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_response_body_failed",
|
||||
}
|
||||
}
|
||||
|
||||
err = json.Unmarshal(responseBody, &midjResponse)
|
||||
log.Printf("responseBody: %s", string(responseBody))
|
||||
log.Printf("midjResponse: %v", midjResponse)
|
||||
if resp.StatusCode != 200 {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "fail_to_fetch_midjourney",
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "unmarshal_response_body_failed",
|
||||
}
|
||||
}
|
||||
if midjResponse.Code == 24 || midjResponse.Code == 21 || midjResponse.Code == 4 {
|
||||
consumeQuota = false
|
||||
}
|
||||
|
||||
midjourneyTask := &model.Midjourney{
|
||||
UserId: userId,
|
||||
Code: midjResponse.Code,
|
||||
Action: midjRequest.Action,
|
||||
MjId: midjResponse.Result,
|
||||
Prompt: midjRequest.Prompt,
|
||||
PromptEn: "",
|
||||
Description: midjResponse.Description,
|
||||
State: "",
|
||||
SubmitTime: 0,
|
||||
StartTime: 0,
|
||||
FinishTime: 0,
|
||||
ImageUrl: "",
|
||||
Status: "",
|
||||
Progress: "0%",
|
||||
FailReason: "",
|
||||
}
|
||||
if midjResponse.Code == 4 {
|
||||
midjourneyTask.FailReason = midjResponse.Description
|
||||
}
|
||||
err = midjourneyTask.Insert()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "insert_midjourney_task_failed",
|
||||
}
|
||||
}
|
||||
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
|
||||
|
||||
for k, v := range resp.Header {
|
||||
c.Writer.Header().Set(k, v[0])
|
||||
}
|
||||
c.Writer.WriteHeader(resp.StatusCode)
|
||||
|
||||
_, err = io.Copy(c.Writer, resp.Body)
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "copy_response_body_failed",
|
||||
}
|
||||
}
|
||||
err = resp.Body.Close()
|
||||
if err != nil {
|
||||
return &MidjourneyResponse{
|
||||
Code: 4,
|
||||
Description: "close_response_body_failed",
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -95,6 +95,33 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
case common.ChannelTypeZhipu:
|
||||
apiType = APITypeZhipu
|
||||
}
|
||||
isStable := c.GetBool("stable")
|
||||
|
||||
//if common.NormalPrice == -1 && strings.HasPrefix(textRequest.Model, "gpt-4") {
|
||||
// nowUser, err := model.GetUserById(userId, false)
|
||||
// if err != nil {
|
||||
// return errorWrapper(err, "get_user_info_failed", http.StatusInternalServerError)
|
||||
// }
|
||||
// if nowUser.StableMode {
|
||||
// group = "svip"
|
||||
// isStable = true
|
||||
// ////stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
|
||||
// //userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64)
|
||||
// //if userMaxPrice < common.StablePrice {
|
||||
// // return errorWrapper(errors.New("当前低价通道不可用,稳定渠道价格为"+strconv.FormatFloat(common.StablePrice, 'f', -1, 64)+"R/刀"), "当前低价通道不可用", http.StatusInternalServerError)
|
||||
// //}
|
||||
// //
|
||||
// ////ratio = stableRatio * groupRatio
|
||||
// //channel, err := model.CacheGetRandomSatisfiedChannel("svip", textRequest.Model)
|
||||
// //if err != nil {
|
||||
// // message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", "svip", textRequest.Model)
|
||||
// // return errorWrapper(errors.New(message), "no_available_channel", http.StatusInternalServerError)
|
||||
// //}
|
||||
// //channelType = channel.Type
|
||||
// } else {
|
||||
// return errorWrapper(errors.New("当前低价通道不可用,请稍后再试,或者在后台开启稳定模式"), "当前低价通道不可用", http.StatusInternalServerError)
|
||||
// }
|
||||
//}
|
||||
baseURL := common.ChannelBaseURLs[channelType]
|
||||
requestURL := c.Request.URL.String()
|
||||
if c.GetString("base_url") != "" {
|
||||
@@ -168,11 +195,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
if textRequest.MaxTokens != 0 {
|
||||
preConsumedTokens = promptTokens + textRequest.MaxTokens
|
||||
}
|
||||
//stableRatio := common.GetStableRatio(textRequest.Model)
|
||||
modelRatio := common.GetModelRatio(textRequest.Model)
|
||||
stableRatio := modelRatio
|
||||
groupRatio := common.GetGroupRatio(group)
|
||||
ratio := modelRatio * groupRatio
|
||||
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
|
||||
userQuota, err := model.CacheGetUserQuota(userId)
|
||||
if isStable {
|
||||
stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
|
||||
ratio = stableRatio * groupRatio
|
||||
}
|
||||
if err != nil {
|
||||
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
|
||||
}
|
||||
@@ -301,6 +334,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
// we cannot just return, because we may have to return the pre-consumed quota
|
||||
quota = 0
|
||||
}
|
||||
//if strings.HasPrefix(textRequest.Model, "gpt-4") {
|
||||
// if quota < 5000 && quota != 0 {
|
||||
// quota = 5000
|
||||
// }
|
||||
//}
|
||||
quotaDelta := quota - preConsumedQuota
|
||||
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
|
||||
if err != nil {
|
||||
@@ -312,8 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
|
||||
}
|
||||
if quota != 0 {
|
||||
tokenName := c.GetString("token_name")
|
||||
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
|
||||
var logContent string
|
||||
if isStable {
|
||||
logContent = fmt.Sprintf("(稳定模式)模型倍率 %.2f,分组倍率 %.2f", stableRatio, groupRatio)
|
||||
} else {
|
||||
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
|
||||
}
|
||||
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
|
||||
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
|
||||
channelId := c.GetInt("channel_id")
|
||||
model.UpdateChannelUsedQuota(channelId, quota)
|
||||
|
||||
@@ -2,6 +2,7 @@ package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"strconv"
|
||||
@@ -24,6 +25,10 @@ const (
|
||||
RelayModeModerations
|
||||
RelayModeImagesGenerations
|
||||
RelayModeEdits
|
||||
RelayModeMidjourneyImagine
|
||||
RelayModeMidjourneyChange
|
||||
RelayModeMidjourneyNotify
|
||||
RelayModeMidjourneyTaskFetch
|
||||
)
|
||||
|
||||
// https://platform.openai.com/docs/api-reference/chat
|
||||
@@ -128,6 +133,23 @@ type CompletionsStreamResponse struct {
|
||||
} `json:"choices"`
|
||||
}
|
||||
|
||||
type MidjourneyRequest struct {
|
||||
Prompt string `json:"prompt"`
|
||||
NotifyHook string `json:"notifyHook"`
|
||||
Action string `json:"action"`
|
||||
Index int `json:"index"`
|
||||
State string `json:"state"`
|
||||
TaskId string `json:"taskId"`
|
||||
Base64Array []string `json:"base64Array"`
|
||||
}
|
||||
|
||||
type MidjourneyResponse struct {
|
||||
Code int `json:"code"`
|
||||
Description string `json:"description"`
|
||||
Properties interface{} `json:"properties"`
|
||||
Result string `json:"result"`
|
||||
}
|
||||
|
||||
func Relay(c *gin.Context) {
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
|
||||
@@ -179,6 +201,54 @@ func Relay(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func RelayMidjourney(c *gin.Context) {
|
||||
relayMode := RelayModeUnknown
|
||||
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
|
||||
relayMode = RelayModeMidjourneyImagine
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
|
||||
relayMode = RelayModeMidjourneyNotify
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
|
||||
relayMode = RelayModeMidjourneyChange
|
||||
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") {
|
||||
relayMode = RelayModeMidjourneyTaskFetch
|
||||
}
|
||||
var err *MidjourneyResponse
|
||||
switch relayMode {
|
||||
case RelayModeMidjourneyNotify:
|
||||
err = relayMidjourneyNotify(c)
|
||||
case RelayModeMidjourneyTaskFetch:
|
||||
err = relayMidjourneyTask(c, relayMode)
|
||||
default:
|
||||
err = relayMidjourneySubmit(c, relayMode)
|
||||
}
|
||||
//err = relayMidjourneySubmit(c, relayMode)
|
||||
log.Println(err)
|
||||
if err != nil {
|
||||
retryTimesStr := c.Query("retry")
|
||||
retryTimes, _ := strconv.Atoi(retryTimesStr)
|
||||
if retryTimesStr == "" {
|
||||
retryTimes = common.RetryTimes
|
||||
}
|
||||
if retryTimes > 0 {
|
||||
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
|
||||
} else {
|
||||
if err.Code == 30 {
|
||||
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
|
||||
}
|
||||
c.JSON(400, gin.H{
|
||||
"error": err.Result,
|
||||
})
|
||||
}
|
||||
channelId := c.GetInt("channel_id")
|
||||
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
|
||||
//if shouldDisableChannel(&err.OpenAIError) {
|
||||
// channelId := c.GetInt("channel_id")
|
||||
// channelName := c.GetString("channel_name")
|
||||
// disableChannel(channelId, channelName, err.Result)
|
||||
//}
|
||||
}
|
||||
}
|
||||
|
||||
func RelayNotImplemented(c *gin.Context) {
|
||||
err := OpenAIError{
|
||||
Message: "API not implemented",
|
||||
|
||||
173
controller/topup.go
Normal file
173
controller/topup.go
Normal file
@@ -0,0 +1,173 @@
|
||||
package controller
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/samber/lo"
|
||||
epay "github.com/star-horizon/go-epay"
|
||||
"log"
|
||||
"net/url"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
"strconv"
|
||||
"time"
|
||||
)
|
||||
|
||||
type EpayRequest struct {
|
||||
Amount int `json:"amount"`
|
||||
PaymentMethod string `json:"payment_method"`
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
}
|
||||
|
||||
type AmountRequest struct {
|
||||
Amount int `json:"amount"`
|
||||
TopUpCode string `json:"top_up_code"`
|
||||
}
|
||||
|
||||
var client, _ = epay.NewClientWithUrl(&epay.Config{
|
||||
PartnerID: "1096",
|
||||
Key: "n08V9LpE8JffA3NPP893689u8p39NV9J",
|
||||
}, "https://api.lempay.org")
|
||||
|
||||
func GetAmount(id int, count float64, topUpCode string) float64 {
|
||||
amount := count * 1.5
|
||||
if topUpCode != "" {
|
||||
if topUpCode == "nekoapi" {
|
||||
if id == 89 {
|
||||
amount = count * 1
|
||||
} else if id == 98 || id == 105 || id == 107 {
|
||||
amount = count * 1.2
|
||||
} else if id == 1 {
|
||||
amount = count * 1
|
||||
}
|
||||
}
|
||||
}
|
||||
return amount
|
||||
}
|
||||
|
||||
func RequestEpay(c *gin.Context) {
|
||||
var req EpayRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
amount := GetAmount(id, float64(req.Amount), req.TopUpCode)
|
||||
if id != 1 {
|
||||
if req.Amount < 10 {
|
||||
c.JSON(200, gin.H{"message": "最小充值10元", "data": amount, "count": 10})
|
||||
return
|
||||
}
|
||||
}
|
||||
if req.PaymentMethod == "zfb" {
|
||||
if amount > 400 {
|
||||
c.JSON(200, gin.H{"message": "支付宝最大充值400元", "data": amount, "count": 400})
|
||||
return
|
||||
}
|
||||
req.PaymentMethod = "alipay"
|
||||
}
|
||||
if req.PaymentMethod == "wx" {
|
||||
if amount > 600 {
|
||||
c.JSON(200, gin.H{"message": "微信最大充值600元", "data": amount, "count": 600})
|
||||
return
|
||||
}
|
||||
req.PaymentMethod = "wxpay"
|
||||
}
|
||||
|
||||
returnUrl, _ := url.Parse("https://nekoapi.com/log")
|
||||
notifyUrl, _ := url.Parse("https://nekoapi.com/api/user/epay/notify")
|
||||
tradeNo := strconv.FormatInt(time.Now().Unix(), 10)
|
||||
uri, params, err := client.Purchase(&epay.PurchaseArgs{
|
||||
Type: epay.PurchaseType(req.PaymentMethod),
|
||||
ServiceTradeNo: "A" + tradeNo,
|
||||
Name: "B" + tradeNo,
|
||||
Money: strconv.FormatFloat(amount*0.99, 'f', 2, 64),
|
||||
Device: epay.PC,
|
||||
NotifyUrl: notifyUrl,
|
||||
ReturnUrl: returnUrl,
|
||||
})
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
|
||||
return
|
||||
}
|
||||
topUp := &model.TopUp{
|
||||
UserId: id,
|
||||
Amount: req.Amount,
|
||||
Money: int(amount),
|
||||
TradeNo: "A" + tradeNo,
|
||||
CreateTime: time.Now().Unix(),
|
||||
Status: "pending",
|
||||
}
|
||||
err = topUp.Insert()
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
|
||||
}
|
||||
|
||||
func EpayNotify(c *gin.Context) {
|
||||
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
|
||||
r[t] = c.Request.URL.Query().Get(t)
|
||||
return r
|
||||
}, map[string]string{})
|
||||
verifyInfo, err := client.Verify(params)
|
||||
if err == nil && verifyInfo.VerifyStatus {
|
||||
_, err := c.Writer.Write([]byte("success"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
} else {
|
||||
_, err := c.Writer.Write([]byte("fail"))
|
||||
if err != nil {
|
||||
log.Println("易支付回调写入失败")
|
||||
}
|
||||
}
|
||||
|
||||
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
|
||||
log.Println(verifyInfo)
|
||||
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
|
||||
if topUp.Status == "pending" {
|
||||
topUp.Status = "success"
|
||||
err := topUp.Update()
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新订单失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
//user, _ := model.GetUserById(topUp.UserId, false)
|
||||
//user.Quota += topUp.Amount * 500000
|
||||
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
|
||||
if err != nil {
|
||||
log.Printf("易支付回调更新用户失败: %v", topUp)
|
||||
return
|
||||
}
|
||||
log.Printf("易支付回调更新用户成功 %v", topUp)
|
||||
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v", common.LogQuota(topUp.Amount*500000)))
|
||||
}
|
||||
} else {
|
||||
log.Printf("易支付异常回调: %v", verifyInfo)
|
||||
}
|
||||
}
|
||||
|
||||
func RequestAmount(c *gin.Context) {
|
||||
var req AmountRequest
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
|
||||
return
|
||||
}
|
||||
id := c.GetInt("id")
|
||||
if id != 1 {
|
||||
if req.Amount < 10 {
|
||||
c.JSON(200, gin.H{"message": "最小充值10刀", "data": GetAmount(id, 10, req.TopUpCode), "count": 10})
|
||||
return
|
||||
}
|
||||
if req.Amount > 400 {
|
||||
c.JSON(200, gin.H{"message": "最大充值400刀", "data": GetAmount(id, 400, req.TopUpCode), "count": 400})
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.JSON(200, gin.H{"message": "success", "data": GetAmount(id, float64(req.Amount), req.TopUpCode)})
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package controller
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/model"
|
||||
@@ -79,6 +80,8 @@ func setupLogin(user *model.User, c *gin.Context) {
|
||||
DisplayName: user.DisplayName,
|
||||
Role: user.Role,
|
||||
Status: user.Status,
|
||||
StableMode: user.StableMode,
|
||||
MaxPrice: user.MaxPrice,
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"message": "",
|
||||
@@ -158,6 +161,8 @@ func Register(c *gin.Context) {
|
||||
Password: user.Password,
|
||||
DisplayName: user.Username,
|
||||
InviterId: inviterId,
|
||||
StableMode: user.StableMode,
|
||||
MaxPrice: user.MaxPrice,
|
||||
}
|
||||
if common.EmailVerificationEnabled {
|
||||
cleanUser.Email = user.Email
|
||||
@@ -420,6 +425,8 @@ func UpdateSelf(c *gin.Context) {
|
||||
Username: user.Username,
|
||||
Password: user.Password,
|
||||
DisplayName: user.DisplayName,
|
||||
StableMode: user.StableMode,
|
||||
MaxPrice: user.MaxPrice,
|
||||
}
|
||||
if user.Password == "$I_LOVE_U" {
|
||||
user.Password = "" // rollback to what it should be
|
||||
@@ -741,3 +748,52 @@ func TopUp(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
type StableModeRequest struct {
|
||||
StableMode bool `json:"stableMode"`
|
||||
MaxPrice string `json:"maxPrice"`
|
||||
}
|
||||
|
||||
func SetTableMode(c *gin.Context) {
|
||||
req := &StableModeRequest{}
|
||||
err := c.ShouldBindJSON(&req)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
log.Println(req)
|
||||
id := c.GetInt("id")
|
||||
user := model.User{
|
||||
Id: id,
|
||||
}
|
||||
err = user.FillUserById()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
user.StableMode = req.StableMode
|
||||
if !req.StableMode {
|
||||
req.MaxPrice = "0"
|
||||
}
|
||||
user.MaxPrice = req.MaxPrice
|
||||
err = user.Update(false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": false,
|
||||
"message": err.Error(),
|
||||
})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"success": true,
|
||||
"message": "",
|
||||
"data": "设置成功",
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user