mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-14 12:23:41 +08:00
✨ add: add images edits and variations API
This commit is contained in:
@@ -146,7 +146,7 @@ func (p *BaseProvider) SupportAPI(relayMode int) bool {
|
||||
return p.Moderation != ""
|
||||
case common.RelayModeImagesGenerations:
|
||||
return p.ImagesGenerations != ""
|
||||
case common.RelayModeImagesEdit:
|
||||
case common.RelayModeImagesEdits:
|
||||
return p.ImagesEdit != ""
|
||||
case common.RelayModeImagesVariations:
|
||||
return p.ImagesVariations != ""
|
||||
|
||||
@@ -56,11 +56,23 @@ type TranslationInterface interface {
|
||||
TranslationAction(request *types.AudioRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 图片生成接口
|
||||
type ImageGenerationsInterface interface {
|
||||
ProviderInterface
|
||||
ImageGenerationsAction(request *types.ImageRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 图片编辑接口
|
||||
type ImageEditsInterface interface {
|
||||
ProviderInterface
|
||||
ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
type ImageVariationsInterface interface {
|
||||
ProviderInterface
|
||||
ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode)
|
||||
}
|
||||
|
||||
// 余额接口
|
||||
type BalanceInterface interface {
|
||||
BalanceAction(channel *model.Channel) (float64, error)
|
||||
|
||||
@@ -39,7 +39,7 @@ func CreateOpenAIProvider(c *gin.Context, baseURL string) *OpenAIProvider {
|
||||
AudioTranscriptions: "/v1/audio/transcriptions",
|
||||
AudioTranslations: "/v1/audio/translations",
|
||||
ImagesGenerations: "/v1/images/generations",
|
||||
ImagesEdit: "/v1/images/edit",
|
||||
ImagesEdit: "/v1/images/edits",
|
||||
ImagesVariations: "/v1/images/variations",
|
||||
Context: c,
|
||||
},
|
||||
|
||||
104
providers/openai/image_edits.go
Normal file
104
providers/openai/image_edits.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) ImageEditsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesEdit, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||
return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func imagesEditsMultipartForm(request *types.ImageEditRequest, b common.FormBuilder) error {
|
||||
err := b.CreateFormFile("image", request.Image)
|
||||
if err != nil {
|
||||
return fmt.Errorf("creating form file: %w", err)
|
||||
}
|
||||
|
||||
err = b.WriteField("prompt", request.Prompt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing prompt: %w", err)
|
||||
}
|
||||
|
||||
err = b.WriteField("model", request.Model)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing model name: %w", err)
|
||||
}
|
||||
|
||||
if request.Mask != nil {
|
||||
err = b.CreateFormFile("mask", request.Mask)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing format: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if request.ResponseFormat != "" {
|
||||
err = b.WriteField("response_format", request.ResponseFormat)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing format: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if request.N != 0 {
|
||||
err = b.WriteField("n", fmt.Sprintf("%.2f", request.N))
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing temperature: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if request.Size != "" {
|
||||
err = b.WriteField("size", request.Size)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing language: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if request.User != "" {
|
||||
err = b.WriteField("user", request.User)
|
||||
if err != nil {
|
||||
return fmt.Errorf("writing language: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return b.Close()
|
||||
}
|
||||
49
providers/openai/image_variations.go
Normal file
49
providers/openai/image_variations.go
Normal file
@@ -0,0 +1,49 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/types"
|
||||
)
|
||||
|
||||
func (p *OpenAIProvider) ImageVariationsAction(request *types.ImageEditRequest, isModelMapped bool, promptTokens int) (usage *types.Usage, errWithCode *types.OpenAIErrorWithStatusCode) {
|
||||
fullRequestURL := p.GetFullRequestURL(p.ImagesVariations, request.Model)
|
||||
headers := p.GetRequestHeaders()
|
||||
|
||||
client := common.NewClient()
|
||||
|
||||
var formBody bytes.Buffer
|
||||
var req *http.Request
|
||||
var err error
|
||||
if isModelMapped {
|
||||
builder := client.CreateFormBuilder(&formBody)
|
||||
if err := imagesEditsMultipartForm(request, builder); err != nil {
|
||||
return nil, types.ErrorWrapper(err, "create_form_builder_failed", http.StatusInternalServerError)
|
||||
}
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(&formBody), common.WithHeader(headers), common.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
} else {
|
||||
req, err = client.NewRequest(p.Context.Request.Method, fullRequestURL, common.WithBody(p.Context.Request.Body), common.WithHeader(headers), common.WithContentType(p.Context.Request.Header.Get("Content-Type")))
|
||||
req.ContentLength = p.Context.Request.ContentLength
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, types.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
openAIProviderImageResponseResponse := &OpenAIProviderImageResponseResponse{}
|
||||
errWithCode = p.SendRequest(req, openAIProviderImageResponseResponse, true)
|
||||
if errWithCode != nil {
|
||||
return
|
||||
}
|
||||
|
||||
usage = &types.Usage{
|
||||
PromptTokens: promptTokens,
|
||||
CompletionTokens: 0,
|
||||
TotalTokens: promptTokens,
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
Reference in New Issue
Block a user