mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-17 21:53:42 +08:00
✨ feat: add Midjourney (#138)
* 🚧 stash * ✨ feat: add Midjourney * 📝 doc: update readme
This commit is contained in:
@@ -114,6 +114,17 @@ func (cc *ChannelsChooser) GetGroupModels(group string) ([]string, error) {
|
||||
return models, nil
|
||||
}
|
||||
|
||||
func (cc *ChannelsChooser) GetChannel(channelId int) *Channel {
|
||||
cc.RLock()
|
||||
defer cc.RUnlock()
|
||||
|
||||
if choice, ok := cc.Channels[channelId]; ok {
|
||||
return choice.Channel
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var ChannelGroup = ChannelsChooser{}
|
||||
|
||||
func (cc *ChannelsChooser) Load() {
|
||||
|
||||
@@ -139,6 +139,10 @@ func InitDB() (err error) {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = db.AutoMigrate(&Midjourney{})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
common.SysLog("database migrated")
|
||||
err = createRootAccountIfNeed()
|
||||
return err
|
||||
|
||||
182
model/midjourney.go
Normal file
182
model/midjourney.go
Normal file
@@ -0,0 +1,182 @@
|
||||
// Copyright (c) 2024 Calcium-Ion
|
||||
//
|
||||
// Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
// of this software and associated documentation files (the "Software"), to deal
|
||||
// in the Software without restriction, including without limitation the rights
|
||||
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
// copies of the Software, and to permit persons to whom the Software is
|
||||
// furnished to do so, subject to the following conditions:
|
||||
//
|
||||
// The above copyright notice and this permission notice shall be included in all
|
||||
// copies or substantial portions of the Software.
|
||||
//
|
||||
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
// SOFTWARE.
|
||||
//
|
||||
// Author: Calcium-Ion
|
||||
// GitHub: https://github.com/Calcium-Ion/new-api
|
||||
|
||||
package model
|
||||
|
||||
type Midjourney struct {
|
||||
Id int `json:"id"`
|
||||
Code int `json:"code"`
|
||||
UserId int `json:"user_id" gorm:"index"`
|
||||
Action string `json:"action" gorm:"type:varchar(40);index"`
|
||||
MjId string `json:"mj_id" gorm:"index"`
|
||||
Prompt string `json:"prompt"`
|
||||
PromptEn string `json:"prompt_en"`
|
||||
Description string `json:"description"`
|
||||
State string `json:"state"`
|
||||
SubmitTime int64 `json:"submit_time" gorm:"index"`
|
||||
StartTime int64 `json:"start_time" gorm:"index"`
|
||||
FinishTime int64 `json:"finish_time" gorm:"index"`
|
||||
ImageUrl string `json:"image_url"`
|
||||
Status string `json:"status" gorm:"type:varchar(20);index"`
|
||||
Progress string `json:"progress" gorm:"type:varchar(30);index"`
|
||||
FailReason string `json:"fail_reason"`
|
||||
ChannelId int `json:"channel_id"`
|
||||
Quota int `json:"quota"`
|
||||
Buttons string `json:"buttons"`
|
||||
Properties string `json:"properties"`
|
||||
}
|
||||
|
||||
// TaskQueryParams 用于包含所有搜索条件的结构体,可以根据需求添加更多字段
|
||||
type TaskQueryParams struct {
|
||||
ChannelID int `form:"channel_id"`
|
||||
MjID string `form:"mj_id"`
|
||||
StartTimestamp int `form:"start_timestamp"`
|
||||
EndTimestamp int `form:"end_timestamp"`
|
||||
PaginationParams
|
||||
}
|
||||
|
||||
var allowedMidjourneyOrderFields = map[string]bool{
|
||||
"id": true,
|
||||
"user_id": true,
|
||||
"code": true,
|
||||
"action": true,
|
||||
"mj_id": true,
|
||||
"submit_time": true,
|
||||
"start_time": true,
|
||||
"finish_time": true,
|
||||
"status": true,
|
||||
"channel_id": true,
|
||||
}
|
||||
|
||||
func GetAllUserTask(userId int, params *TaskQueryParams) (*DataResult[Midjourney], error) {
|
||||
var tasks []*Midjourney
|
||||
|
||||
// 初始化查询构建器
|
||||
query := DB.Where("user_id = ?", userId)
|
||||
|
||||
if params.MjID != "" {
|
||||
query = query.Where("mj_id = ?", params.MjID)
|
||||
}
|
||||
if params.StartTimestamp != 0 {
|
||||
// 假设您已将前端传来的时间戳转换为数据库所需的时间格式,并处理了时间戳的验证和解析
|
||||
query = query.Where("submit_time >= ?", params.StartTimestamp)
|
||||
}
|
||||
if params.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", params.EndTimestamp)
|
||||
}
|
||||
|
||||
return PaginateAndOrder(query, ¶ms.PaginationParams, &tasks, allowedMidjourneyOrderFields)
|
||||
}
|
||||
|
||||
func GetAllTasks(params *TaskQueryParams) (*DataResult[Midjourney], error) {
|
||||
var tasks []*Midjourney
|
||||
|
||||
// 初始化查询构建器
|
||||
query := DB
|
||||
|
||||
// 添加过滤条件
|
||||
if params.ChannelID != 0 {
|
||||
query = query.Where("channel_id = ?", params.ChannelID)
|
||||
}
|
||||
if params.MjID != "" {
|
||||
query = query.Where("mj_id = ?", params.MjID)
|
||||
}
|
||||
if params.StartTimestamp != 0 {
|
||||
query = query.Where("submit_time >= ?", params.StartTimestamp)
|
||||
}
|
||||
if params.EndTimestamp != 0 {
|
||||
query = query.Where("submit_time <= ?", params.EndTimestamp)
|
||||
}
|
||||
|
||||
return PaginateAndOrder(query, ¶ms.PaginationParams, &tasks, allowedMidjourneyOrderFields)
|
||||
}
|
||||
|
||||
func GetAllUnFinishTasks() []*Midjourney {
|
||||
var tasks []*Midjourney
|
||||
// get all tasks progress is not 100%
|
||||
err := DB.Where("progress != ?", "100%").Find(&tasks).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return tasks
|
||||
}
|
||||
|
||||
func GetByOnlyMJId(mjId string) *Midjourney {
|
||||
var mj *Midjourney
|
||||
err := DB.Where("mj_id = ?", mjId).First(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func GetByMJId(userId int, mjId string) *Midjourney {
|
||||
var mj *Midjourney
|
||||
err := DB.Where("user_id = ? and mj_id = ?", userId, mjId).First(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func GetByMJIds(userId int, mjIds []string) []*Midjourney {
|
||||
var mj []*Midjourney
|
||||
err := DB.Where("user_id = ? and mj_id in (?)", userId, mjIds).Find(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func GetMjByuId(id int) *Midjourney {
|
||||
var mj *Midjourney
|
||||
err := DB.Where("id = ?", id).First(&mj).Error
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return mj
|
||||
}
|
||||
|
||||
func UpdateProgress(id int, progress string) error {
|
||||
return DB.Model(&Midjourney{}).Where("id = ?", id).Update("progress", progress).Error
|
||||
}
|
||||
|
||||
func (midjourney *Midjourney) Insert() error {
|
||||
return DB.Create(midjourney).Error
|
||||
}
|
||||
|
||||
func (midjourney *Midjourney) Update() error {
|
||||
return DB.Save(midjourney).Error
|
||||
}
|
||||
|
||||
func MjBulkUpdate(mjIds []string, params map[string]any) error {
|
||||
return DB.Model(&Midjourney{}).
|
||||
Where("mj_id in (?)", mjIds).
|
||||
Updates(params).Error
|
||||
}
|
||||
|
||||
func MjBulkUpdateByTaskIds(taskIDs []int, params map[string]any) error {
|
||||
return DB.Model(&Midjourney{}).
|
||||
Where("id in (?)", taskIDs).
|
||||
Updates(params).Error
|
||||
}
|
||||
@@ -74,6 +74,8 @@ func InitOptionMap() {
|
||||
common.OptionMap["RetryTimes"] = strconv.Itoa(common.RetryTimes)
|
||||
common.OptionMap["RetryCooldownSeconds"] = strconv.Itoa(common.RetryCooldownSeconds)
|
||||
|
||||
common.OptionMap["MjNotifyEnabled"] = strconv.FormatBool(common.MjNotifyEnabled)
|
||||
|
||||
common.OptionMapRWMutex.Unlock()
|
||||
loadOptionsFromDatabase()
|
||||
}
|
||||
@@ -138,6 +140,7 @@ var optionBoolMap = map[string]*bool{
|
||||
"LogConsumeEnabled": &common.LogConsumeEnabled,
|
||||
"DisplayInCurrencyEnabled": &common.DisplayInCurrencyEnabled,
|
||||
"DisplayTokenStatEnabled": &common.DisplayTokenStatEnabled,
|
||||
"MjNotifyEnabled": &common.MjNotifyEnabled,
|
||||
}
|
||||
|
||||
var optionStringMap = map[string]*string{
|
||||
|
||||
@@ -301,5 +301,33 @@ func GetDefaultPrice() []*Price {
|
||||
})
|
||||
}
|
||||
|
||||
var DefaultMJPrice = map[string]float64{
|
||||
"mj_imagine": 50,
|
||||
"mj_variation": 50,
|
||||
"mj_reroll": 50,
|
||||
"mj_blend": 50,
|
||||
"mj_modal": 50,
|
||||
"mj_zoom": 50,
|
||||
"mj_shorten": 50,
|
||||
"mj_high_variation": 50,
|
||||
"mj_low_variation": 50,
|
||||
"mj_pan": 50,
|
||||
"mj_inpaint": 0,
|
||||
"mj_custom_zoom": 0,
|
||||
"mj_describe": 25,
|
||||
"mj_upscale": 25,
|
||||
"swap_face": 25,
|
||||
}
|
||||
|
||||
for model, mjPrice := range DefaultMJPrice {
|
||||
prices = append(prices, &Price{
|
||||
Model: model,
|
||||
Type: TimesPriceType,
|
||||
ChannelType: common.ChannelTypeMidjourney,
|
||||
Input: mjPrice,
|
||||
Output: mjPrice,
|
||||
})
|
||||
}
|
||||
|
||||
return prices
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user