This commit is contained in:
CaIon
2023-08-14 22:16:32 +08:00
parent c134604cee
commit 8f2119e410
33 changed files with 3224 additions and 1138 deletions

View File

@@ -94,6 +94,23 @@ func SearchUserLogs(c *gin.Context) {
})
}
func GetLogByKey(c *gin.Context) {
key := c.Query("key")
logs, err := model.GetLogByKey(key)
if err != nil {
c.JSON(200, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func GetLogsStat(c *gin.Context) {
logType, _ := strconv.Atoi(c.Query("type"))
startTimestamp, _ := strconv.ParseInt(c.Query("start_timestamp"), 10, 64)

133
controller/midjourney.go Normal file
View File

@@ -0,0 +1,133 @@
package controller
import (
"bytes"
"encoding/json"
"fmt"
"github.com/gin-gonic/gin"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
func UpdateMidjourneyTask() {
imageModel := "midjourney"
for {
time.Sleep(time.Duration(15) * time.Second)
tasks := model.GetAllUnFinishTasks()
if len(tasks) != 0 {
//log.Printf("UpdateMidjourneyTask: %v", time.Now())
ids := make([]string, 0)
for _, task := range tasks {
ids = append(ids, task.MjId)
}
requestUrl := "http://107.173.171.147:8080/mj/task/list-by-condition"
requestBody := map[string]interface{}{
"ids": ids,
}
jsonStr, err := json.Marshal(requestBody)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
}
req, err := http.NewRequest("POST", requestUrl, bytes.NewBuffer(jsonStr))
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("mj-api-secret", "uhiftyuwadbkjshbiklahcuitguasguzhxliawodawdu")
resp, err := httpClient.Do(req)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
}
defer resp.Body.Close()
var response []Midjourney
err = json.NewDecoder(resp.Body).Decode(&response)
if err != nil {
log.Printf("UpdateMidjourneyTask: %v", err)
}
for _, responseItem := range response {
var midjourneyTask *model.Midjourney
for _, mj := range tasks {
mj.MjId = responseItem.MjId
midjourneyTask = model.GetMjByuId(mj.Id)
}
if midjourneyTask != nil {
midjourneyTask.Code = 1
midjourneyTask.Progress = responseItem.Progress
midjourneyTask.PromptEn = responseItem.PromptEn
midjourneyTask.State = responseItem.State
midjourneyTask.SubmitTime = responseItem.SubmitTime
midjourneyTask.StartTime = responseItem.StartTime
midjourneyTask.FinishTime = responseItem.FinishTime
midjourneyTask.ImageUrl = responseItem.ImageUrl
midjourneyTask.Status = responseItem.Status
midjourneyTask.FailReason = responseItem.FailReason
if midjourneyTask.Progress != "100%" && responseItem.FailReason != "" {
log.Println(midjourneyTask.MjId + " 构建失败," + midjourneyTask.FailReason)
midjourneyTask.Progress = "100%"
err = model.CacheUpdateUserQuota(midjourneyTask.UserId)
if err != nil {
log.Println("error update user quota cache: " + err.Error())
} else {
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio("default")
ratio := modelRatio * groupRatio
quota := int(ratio * 1 * 1000)
if quota != 0 {
err := model.IncreaseUserQuota(midjourneyTask.UserId, quota)
if err != nil {
log.Println("fail to increase user quota")
}
logContent := fmt.Sprintf("%s 构图失败,补偿 %s", midjourneyTask.MjId, common.LogQuota(quota))
model.RecordLog(midjourneyTask.UserId, 1, logContent)
}
}
}
err = midjourneyTask.Update()
if err != nil {
log.Printf("UpdateMidjourneyTaskFail: %v", err)
}
log.Printf("UpdateMidjourneyTask: %v", midjourneyTask)
}
}
}
}
}
func GetAllMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
logs := model.GetAllTasks(p*common.ItemsPerPage, common.ItemsPerPage)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}
func GetUserMidjourney(c *gin.Context) {
p, _ := strconv.Atoi(c.Query("p"))
if p < 0 {
p = 0
}
userId := c.GetInt("id")
log.Printf("userId = %d \n", userId)
logs := model.GetAllUserTask(userId, p*common.ItemsPerPage, common.ItemsPerPage)
if logs == nil {
logs = make([]*model.Midjourney, 0)
}
c.JSON(200, gin.H{
"success": true,
"message": "",
"data": logs,
})
}

View File

@@ -31,6 +31,9 @@ func GetStatus(c *gin.Context) {
"chat_link": common.ChatLink,
"quota_per_unit": common.QuotaPerUnit,
"display_in_currency": common.DisplayInCurrencyEnabled,
"normal_price": common.NormalPrice,
"stable_price": common.StablePrice,
"base_price": common.BasePrice,
},
})
return
@@ -58,6 +61,17 @@ func GetAbout(c *gin.Context) {
return
}
func GetMidjourney(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": common.OptionMap["Midjourney"],
})
return
}
func GetHomePageContent(c *gin.Context) {
common.OptionMapRWMutex.RLock()
defer common.OptionMapRWMutex.RUnlock()

View File

@@ -137,7 +137,7 @@ func relayImageHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

388
controller/relay-mj.go Normal file
View File

@@ -0,0 +1,388 @@
package controller
import (
"bytes"
"encoding/json"
"fmt"
"io"
"log"
"net/http"
"one-api/common"
"one-api/model"
"strings"
"github.com/gin-gonic/gin"
)
type Midjourney struct {
MjId string `json:"id"`
Action string `json:"action"`
Prompt string `json:"prompt"`
PromptEn string `json:"promptEn"`
Description string `json:"description"`
State string `json:"state"`
SubmitTime int64 `json:"submitTime"`
StartTime int64 `json:"startTime"`
FinishTime int64 `json:"finishTime"`
ImageUrl string `json:"imageUrl"`
Status string `json:"status"`
Progress string `json:"progress"`
FailReason string `json:"failReason"`
}
func relayMidjourneyNotify(c *gin.Context) *MidjourneyResponse {
var midjRequest Midjourney
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
Properties: nil,
Result: "",
}
}
midjourneyTask := model.GetByMJId(midjRequest.MjId)
if midjourneyTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "midjourney_task_not_found",
Properties: nil,
Result: "",
}
}
midjourneyTask.Progress = midjRequest.Progress
midjourneyTask.PromptEn = midjRequest.PromptEn
midjourneyTask.State = midjRequest.State
midjourneyTask.SubmitTime = midjRequest.SubmitTime
midjourneyTask.StartTime = midjRequest.StartTime
midjourneyTask.FinishTime = midjRequest.FinishTime
midjourneyTask.ImageUrl = midjRequest.ImageUrl
midjourneyTask.Status = midjRequest.Status
midjourneyTask.FailReason = midjRequest.FailReason
err = midjourneyTask.Update()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "update_midjourney_task_failed",
}
}
return nil
}
func relayMidjourneyTask(c *gin.Context, relayMode int) *MidjourneyResponse {
taskId := c.Param("id")
originTask := model.GetByMJId(taskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
}
var midjourneyTask Midjourney
midjourneyTask.MjId = originTask.MjId
midjourneyTask.Progress = originTask.Progress
midjourneyTask.PromptEn = originTask.PromptEn
midjourneyTask.State = originTask.State
midjourneyTask.SubmitTime = originTask.SubmitTime
midjourneyTask.StartTime = originTask.StartTime
midjourneyTask.FinishTime = originTask.FinishTime
midjourneyTask.ImageUrl = originTask.ImageUrl
midjourneyTask.Status = originTask.Status
midjourneyTask.FailReason = originTask.FailReason
midjourneyTask.Action = originTask.Action
midjourneyTask.Description = originTask.Description
midjourneyTask.Prompt = originTask.Prompt
jsonMap, err := json.Marshal(midjourneyTask)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
_, err = io.Copy(c.Writer, bytes.NewBuffer(jsonMap))
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
return nil
}
func relayMidjourneySubmit(c *gin.Context, relayMode int) *MidjourneyResponse {
imageModel := "midjourney"
tokenId := c.GetInt("token_id")
channelType := c.GetInt("channel")
userId := c.GetInt("id")
consumeQuota := c.GetBool("consume_quota")
group := c.GetString("group")
var midjRequest MidjourneyRequest
if consumeQuota {
err := common.UnmarshalBodyReusable(c, &midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "bind_request_body_failed",
}
}
}
if relayMode == RelayModeMidjourneyImagine {
if midjRequest.Prompt == "" {
return &MidjourneyResponse{
Code: 4,
Description: "prompt_is_required",
}
}
midjRequest.Action = "IMAGINE"
} else if midjRequest.TaskId != "" {
originTask := model.GetByMJId(midjRequest.TaskId)
if originTask == nil {
return &MidjourneyResponse{
Code: 4,
Description: "task_no_found",
}
} else if originTask.Action == "UPSCALE" {
//return errorWrapper(errors.New("upscale task can not be change"), "request_params_error", http.StatusBadRequest).
return &MidjourneyResponse{
Code: 4,
Description: "upscale_task_can_not_be_change",
}
} else if originTask.Status != "SUCCESS" {
return &MidjourneyResponse{
Code: 4,
Description: "task_status_is_not_success",
}
}
midjRequest.Prompt = originTask.Prompt
} else if relayMode == RelayModeMidjourneyChange {
if midjRequest.TaskId == "" {
return &MidjourneyResponse{
Code: 4,
Description: "taskId_is_required",
}
} else if midjRequest.Action == "" {
return &MidjourneyResponse{
Code: 4,
Description: "action_is_required",
}
} else if midjRequest.Index == 0 {
return &MidjourneyResponse{
Code: 4,
Description: "index_can_only_be_1_2_3_4",
}
}
}
// map model name
modelMapping := c.GetString("model_mapping")
isModelMapped := false
if modelMapping != "" {
modelMap := make(map[string]string)
err := json.Unmarshal([]byte(modelMapping), &modelMap)
if err != nil {
//return errorWrapper(err, "unmarshal_model_mapping_failed", http.StatusInternalServerError)
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_model_mapping_failed",
}
}
if modelMap[imageModel] != "" {
imageModel = modelMap[imageModel]
isModelMapped = true
}
}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
baseURL = c.GetString("base_url")
}
//midjRequest.NotifyHook = "http://127.0.0.1:3000/mj/notify"
fullRequestURL := fmt.Sprintf("%s%s", baseURL, requestURL)
var requestBody io.Reader
if isModelMapped {
jsonStr, err := json.Marshal(midjRequest)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "marshal_text_request_failed",
}
}
requestBody = bytes.NewBuffer(jsonStr)
} else {
requestBody = c.Request.Body
}
modelRatio := common.GetModelRatio(imageModel)
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
userQuota, err := model.CacheGetUserQuota(userId)
sizeRatio := 1.0
quota := int(ratio * sizeRatio * 1000)
if consumeQuota && userQuota-quota < 0 {
return &MidjourneyResponse{
Code: 4,
Description: "quota_not_enough",
}
}
req, err := http.NewRequest(c.Request.Method, fullRequestURL, requestBody)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "create_request_failed",
}
}
//req.Header.Set("Authorization", c.Request.Header.Get("Authorization"))
req.Header.Set("Content-Type", c.Request.Header.Get("Content-Type"))
req.Header.Set("Accept", c.Request.Header.Get("Accept"))
req.Header.Set("mj-api-secret", strings.Split(c.Request.Header.Get("Authorization"), " ")[1])
// print request header
log.Printf("request header: %s", req.Header)
log.Printf("request body: %s", midjRequest.Prompt)
resp, err := httpClient.Do(req)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "do_request_failed",
}
}
err = req.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
err = c.Request.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_request_body_failed",
}
}
var midjResponse MidjourneyResponse
defer func() {
if consumeQuota {
err := model.PostConsumeTokenQuota(tokenId, quota)
if err != nil {
common.SysError("error consuming token remain quota: " + err.Error())
}
err = model.CacheUpdateUserQuota(userId)
if err != nil {
common.SysError("error update user quota cache: " + err.Error())
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, 0, 0, imageModel, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)
}
}
}()
//if consumeQuota {
//
//}
responseBody, err := io.ReadAll(resp.Body)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "read_response_body_failed",
}
}
err = resp.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
err = json.Unmarshal(responseBody, &midjResponse)
log.Printf("responseBody: %s", string(responseBody))
log.Printf("midjResponse: %v", midjResponse)
if resp.StatusCode != 200 {
return &MidjourneyResponse{
Code: 4,
Description: "fail_to_fetch_midjourney",
}
}
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "unmarshal_response_body_failed",
}
}
if midjResponse.Code == 24 || midjResponse.Code == 21 || midjResponse.Code == 4 {
consumeQuota = false
}
midjourneyTask := &model.Midjourney{
UserId: userId,
Code: midjResponse.Code,
Action: midjRequest.Action,
MjId: midjResponse.Result,
Prompt: midjRequest.Prompt,
PromptEn: "",
Description: midjResponse.Description,
State: "",
SubmitTime: 0,
StartTime: 0,
FinishTime: 0,
ImageUrl: "",
Status: "",
Progress: "0%",
FailReason: "",
}
if midjResponse.Code == 4 {
midjourneyTask.FailReason = midjResponse.Description
}
err = midjourneyTask.Insert()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "insert_midjourney_task_failed",
}
}
resp.Body = io.NopCloser(bytes.NewBuffer(responseBody))
for k, v := range resp.Header {
c.Writer.Header().Set(k, v[0])
}
c.Writer.WriteHeader(resp.StatusCode)
_, err = io.Copy(c.Writer, resp.Body)
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "copy_response_body_failed",
}
}
err = resp.Body.Close()
if err != nil {
return &MidjourneyResponse{
Code: 4,
Description: "close_response_body_failed",
}
}
return nil
}

View File

@@ -95,6 +95,33 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
case common.ChannelTypeZhipu:
apiType = APITypeZhipu
}
isStable := c.GetBool("stable")
//if common.NormalPrice == -1 && strings.HasPrefix(textRequest.Model, "gpt-4") {
// nowUser, err := model.GetUserById(userId, false)
// if err != nil {
// return errorWrapper(err, "get_user_info_failed", http.StatusInternalServerError)
// }
// if nowUser.StableMode {
// group = "svip"
// isStable = true
// ////stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
// //userMaxPrice, _ := strconv.ParseFloat(nowUser.MaxPrice, 64)
// //if userMaxPrice < common.StablePrice {
// // return errorWrapper(errors.New("当前低价通道不可用,稳定渠道价格为"+strconv.FormatFloat(common.StablePrice, 'f', -1, 64)+"R/刀"), "当前低价通道不可用", http.StatusInternalServerError)
// //}
// //
// ////ratio = stableRatio * groupRatio
// //channel, err := model.CacheGetRandomSatisfiedChannel("svip", textRequest.Model)
// //if err != nil {
// // message := fmt.Sprintf("当前分组 %s 下对于模型 %s 无可用渠道", "svip", textRequest.Model)
// // return errorWrapper(errors.New(message), "no_available_channel", http.StatusInternalServerError)
// //}
// //channelType = channel.Type
// } else {
// return errorWrapper(errors.New("当前低价通道不可用,请稍后再试,或者在后台开启稳定模式"), "当前低价通道不可用", http.StatusInternalServerError)
// }
//}
baseURL := common.ChannelBaseURLs[channelType]
requestURL := c.Request.URL.String()
if c.GetString("base_url") != "" {
@@ -168,11 +195,17 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
if textRequest.MaxTokens != 0 {
preConsumedTokens = promptTokens + textRequest.MaxTokens
}
//stableRatio := common.GetStableRatio(textRequest.Model)
modelRatio := common.GetModelRatio(textRequest.Model)
stableRatio := modelRatio
groupRatio := common.GetGroupRatio(group)
ratio := modelRatio * groupRatio
preConsumedQuota := int(float64(preConsumedTokens) * ratio)
userQuota, err := model.CacheGetUserQuota(userId)
if isStable {
stableRatio = (common.StablePrice / common.BasePrice) * modelRatio
ratio = stableRatio * groupRatio
}
if err != nil {
return errorWrapper(err, "get_user_quota_failed", http.StatusInternalServerError)
}
@@ -301,6 +334,11 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
// we cannot just return, because we may have to return the pre-consumed quota
quota = 0
}
//if strings.HasPrefix(textRequest.Model, "gpt-4") {
// if quota < 5000 && quota != 0 {
// quota = 5000
// }
//}
quotaDelta := quota - preConsumedQuota
err := model.PostConsumeTokenQuota(tokenId, quotaDelta)
if err != nil {
@@ -312,8 +350,13 @@ func relayTextHelper(c *gin.Context, relayMode int) *OpenAIErrorWithStatusCode {
}
if quota != 0 {
tokenName := c.GetString("token_name")
logContent := fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent)
var logContent string
if isStable {
logContent = fmt.Sprintf("(稳定模式)模型倍率 %.2f,分组倍率 %.2f", stableRatio, groupRatio)
} else {
logContent = fmt.Sprintf("模型倍率 %.2f,分组倍率 %.2f", modelRatio, groupRatio)
}
model.RecordConsumeLog(userId, promptTokens, completionTokens, textRequest.Model, tokenName, quota, logContent, tokenId)
model.UpdateUserUsedQuotaAndRequestCount(userId, quota)
channelId := c.GetInt("channel_id")
model.UpdateChannelUsedQuota(channelId, quota)

View File

@@ -2,6 +2,7 @@ package controller
import (
"fmt"
"log"
"net/http"
"one-api/common"
"strconv"
@@ -24,6 +25,10 @@ const (
RelayModeModerations
RelayModeImagesGenerations
RelayModeEdits
RelayModeMidjourneyImagine
RelayModeMidjourneyChange
RelayModeMidjourneyNotify
RelayModeMidjourneyTaskFetch
)
// https://platform.openai.com/docs/api-reference/chat
@@ -128,6 +133,23 @@ type CompletionsStreamResponse struct {
} `json:"choices"`
}
type MidjourneyRequest struct {
Prompt string `json:"prompt"`
NotifyHook string `json:"notifyHook"`
Action string `json:"action"`
Index int `json:"index"`
State string `json:"state"`
TaskId string `json:"taskId"`
Base64Array []string `json:"base64Array"`
}
type MidjourneyResponse struct {
Code int `json:"code"`
Description string `json:"description"`
Properties interface{} `json:"properties"`
Result string `json:"result"`
}
func Relay(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/v1/chat/completions") {
@@ -179,6 +201,54 @@ func Relay(c *gin.Context) {
}
}
func RelayMidjourney(c *gin.Context) {
relayMode := RelayModeUnknown
if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/imagine") {
relayMode = RelayModeMidjourneyImagine
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/notify") {
relayMode = RelayModeMidjourneyNotify
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/submit/change") {
relayMode = RelayModeMidjourneyChange
} else if strings.HasPrefix(c.Request.URL.Path, "/mj/task") {
relayMode = RelayModeMidjourneyTaskFetch
}
var err *MidjourneyResponse
switch relayMode {
case RelayModeMidjourneyNotify:
err = relayMidjourneyNotify(c)
case RelayModeMidjourneyTaskFetch:
err = relayMidjourneyTask(c, relayMode)
default:
err = relayMidjourneySubmit(c, relayMode)
}
//err = relayMidjourneySubmit(c, relayMode)
log.Println(err)
if err != nil {
retryTimesStr := c.Query("retry")
retryTimes, _ := strconv.Atoi(retryTimesStr)
if retryTimesStr == "" {
retryTimes = common.RetryTimes
}
if retryTimes > 0 {
c.Redirect(http.StatusTemporaryRedirect, fmt.Sprintf("%s?retry=%d", c.Request.URL.Path, retryTimes-1))
} else {
if err.Code == 30 {
err.Result = "当前分组负载已饱和,请稍后再试,或升级账户以提升服务质量。"
}
c.JSON(400, gin.H{
"error": err.Result,
})
}
channelId := c.GetInt("channel_id")
common.SysError(fmt.Sprintf("relay error (channel #%d): %s", channelId, err.Result))
//if shouldDisableChannel(&err.OpenAIError) {
// channelId := c.GetInt("channel_id")
// channelName := c.GetString("channel_name")
// disableChannel(channelId, channelName, err.Result)
//}
}
}
func RelayNotImplemented(c *gin.Context) {
err := OpenAIError{
Message: "API not implemented",

173
controller/topup.go Normal file
View File

@@ -0,0 +1,173 @@
package controller
import (
"fmt"
"github.com/gin-gonic/gin"
"github.com/samber/lo"
epay "github.com/star-horizon/go-epay"
"log"
"net/url"
"one-api/common"
"one-api/model"
"strconv"
"time"
)
type EpayRequest struct {
Amount int `json:"amount"`
PaymentMethod string `json:"payment_method"`
TopUpCode string `json:"top_up_code"`
}
type AmountRequest struct {
Amount int `json:"amount"`
TopUpCode string `json:"top_up_code"`
}
var client, _ = epay.NewClientWithUrl(&epay.Config{
PartnerID: "1096",
Key: "n08V9LpE8JffA3NPP893689u8p39NV9J",
}, "https://api.lempay.org")
func GetAmount(id int, count float64, topUpCode string) float64 {
amount := count * 1.5
if topUpCode != "" {
if topUpCode == "nekoapi" {
if id == 89 {
amount = count * 1
} else if id == 98 || id == 105 || id == 107 {
amount = count * 1.2
} else if id == 1 {
amount = count * 1
}
}
}
return amount
}
func RequestEpay(c *gin.Context) {
var req EpayRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return
}
id := c.GetInt("id")
amount := GetAmount(id, float64(req.Amount), req.TopUpCode)
if id != 1 {
if req.Amount < 10 {
c.JSON(200, gin.H{"message": "最小充值10元", "data": amount, "count": 10})
return
}
}
if req.PaymentMethod == "zfb" {
if amount > 400 {
c.JSON(200, gin.H{"message": "支付宝最大充值400元", "data": amount, "count": 400})
return
}
req.PaymentMethod = "alipay"
}
if req.PaymentMethod == "wx" {
if amount > 600 {
c.JSON(200, gin.H{"message": "微信最大充值600元", "data": amount, "count": 600})
return
}
req.PaymentMethod = "wxpay"
}
returnUrl, _ := url.Parse("https://nekoapi.com/log")
notifyUrl, _ := url.Parse("https://nekoapi.com/api/user/epay/notify")
tradeNo := strconv.FormatInt(time.Now().Unix(), 10)
uri, params, err := client.Purchase(&epay.PurchaseArgs{
Type: epay.PurchaseType(req.PaymentMethod),
ServiceTradeNo: "A" + tradeNo,
Name: "B" + tradeNo,
Money: strconv.FormatFloat(amount*0.99, 'f', 2, 64),
Device: epay.PC,
NotifyUrl: notifyUrl,
ReturnUrl: returnUrl,
})
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "拉起支付失败"})
return
}
topUp := &model.TopUp{
UserId: id,
Amount: req.Amount,
Money: int(amount),
TradeNo: "A" + tradeNo,
CreateTime: time.Now().Unix(),
Status: "pending",
}
err = topUp.Insert()
if err != nil {
c.JSON(200, gin.H{"message": "error", "data": "创建订单失败"})
return
}
c.JSON(200, gin.H{"message": "success", "data": params, "url": uri})
}
func EpayNotify(c *gin.Context) {
params := lo.Reduce(lo.Keys(c.Request.URL.Query()), func(r map[string]string, t string, i int) map[string]string {
r[t] = c.Request.URL.Query().Get(t)
return r
}, map[string]string{})
verifyInfo, err := client.Verify(params)
if err == nil && verifyInfo.VerifyStatus {
_, err := c.Writer.Write([]byte("success"))
if err != nil {
log.Println("易支付回调写入失败")
}
} else {
_, err := c.Writer.Write([]byte("fail"))
if err != nil {
log.Println("易支付回调写入失败")
}
}
if verifyInfo.TradeStatus == epay.StatusTradeSuccess {
log.Println(verifyInfo)
topUp := model.GetTopUpByTradeNo(verifyInfo.ServiceTradeNo)
if topUp.Status == "pending" {
topUp.Status = "success"
err := topUp.Update()
if err != nil {
log.Printf("易支付回调更新订单失败: %v", topUp)
return
}
//user, _ := model.GetUserById(topUp.UserId, false)
//user.Quota += topUp.Amount * 500000
err = model.IncreaseUserQuota(topUp.UserId, topUp.Amount*500000)
if err != nil {
log.Printf("易支付回调更新用户失败: %v", topUp)
return
}
log.Printf("易支付回调更新用户成功 %v", topUp)
model.RecordLog(topUp.UserId, model.LogTypeTopup, fmt.Sprintf("使用在线充值成功,充值金额: %v", common.LogQuota(topUp.Amount*500000)))
}
} else {
log.Printf("易支付异常回调: %v", verifyInfo)
}
}
func RequestAmount(c *gin.Context) {
var req AmountRequest
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(200, gin.H{"message": err.Error(), "data": 10})
return
}
id := c.GetInt("id")
if id != 1 {
if req.Amount < 10 {
c.JSON(200, gin.H{"message": "最小充值10刀", "data": GetAmount(id, 10, req.TopUpCode), "count": 10})
return
}
if req.Amount > 400 {
c.JSON(200, gin.H{"message": "最大充值400刀", "data": GetAmount(id, 400, req.TopUpCode), "count": 400})
return
}
}
c.JSON(200, gin.H{"message": "success", "data": GetAmount(id, float64(req.Amount), req.TopUpCode)})
}

View File

@@ -3,6 +3,7 @@ package controller
import (
"encoding/json"
"fmt"
"log"
"net/http"
"one-api/common"
"one-api/model"
@@ -79,6 +80,8 @@ func setupLogin(user *model.User, c *gin.Context) {
DisplayName: user.DisplayName,
Role: user.Role,
Status: user.Status,
StableMode: user.StableMode,
MaxPrice: user.MaxPrice,
}
c.JSON(http.StatusOK, gin.H{
"message": "",
@@ -158,6 +161,8 @@ func Register(c *gin.Context) {
Password: user.Password,
DisplayName: user.Username,
InviterId: inviterId,
StableMode: user.StableMode,
MaxPrice: user.MaxPrice,
}
if common.EmailVerificationEnabled {
cleanUser.Email = user.Email
@@ -420,6 +425,8 @@ func UpdateSelf(c *gin.Context) {
Username: user.Username,
Password: user.Password,
DisplayName: user.DisplayName,
StableMode: user.StableMode,
MaxPrice: user.MaxPrice,
}
if user.Password == "$I_LOVE_U" {
user.Password = "" // rollback to what it should be
@@ -741,3 +748,52 @@ func TopUp(c *gin.Context) {
})
return
}
type StableModeRequest struct {
StableMode bool `json:"stableMode"`
MaxPrice string `json:"maxPrice"`
}
func SetTableMode(c *gin.Context) {
req := &StableModeRequest{}
err := c.ShouldBindJSON(&req)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
log.Println(req)
id := c.GetInt("id")
user := model.User{
Id: id,
}
err = user.FillUserById()
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
user.StableMode = req.StableMode
if !req.StableMode {
req.MaxPrice = "0"
}
user.MaxPrice = req.MaxPrice
err = user.Update(false)
if err != nil {
c.JSON(http.StatusOK, gin.H{
"success": false,
"message": err.Error(),
})
return
}
c.JSON(http.StatusOK, gin.H{
"success": true,
"message": "",
"data": "设置成功",
})
return
}