mirror of
https://github.com/songquanpeng/one-api.git
synced 2025-11-16 21:23:44 +08:00
✨ feat: add Stability AI
This commit is contained in:
@@ -21,6 +21,7 @@ import (
|
||||
"one-api/providers/mistral"
|
||||
"one-api/providers/openai"
|
||||
"one-api/providers/palm"
|
||||
"one-api/providers/stabilityAI"
|
||||
"one-api/providers/tencent"
|
||||
"one-api/providers/xunfei"
|
||||
"one-api/providers/zhipu"
|
||||
@@ -58,6 +59,7 @@ func init() {
|
||||
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
|
||||
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
|
||||
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
|
||||
providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{}
|
||||
|
||||
}
|
||||
|
||||
|
||||
79
providers/stabilityAI/base.go
Normal file
79
providers/stabilityAI/base.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package stabilityAI
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"one-api/common/requester"
|
||||
"one-api/model"
|
||||
"one-api/providers/base"
|
||||
"one-api/types"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type StabilityAIProviderFactory struct{}
|
||||
|
||||
// 创建 StabilityAIProvider
|
||||
func (f StabilityAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
|
||||
return &StabilityAIProvider{
|
||||
BaseProvider: base.BaseProvider{
|
||||
Config: getConfig(),
|
||||
Channel: channel,
|
||||
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
type StabilityAIProvider struct {
|
||||
base.BaseProvider
|
||||
}
|
||||
|
||||
func getConfig() base.ProviderConfig {
|
||||
return base.ProviderConfig{
|
||||
BaseURL: "https://api.stability.ai/v2beta",
|
||||
ImagesGenerations: "/stable-image/generate",
|
||||
}
|
||||
}
|
||||
|
||||
// 请求错误处理
|
||||
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
|
||||
stabilityAIError := &StabilityAIError{}
|
||||
err := json.NewDecoder(resp.Body).Decode(stabilityAIError)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errorHandle(stabilityAIError)
|
||||
}
|
||||
|
||||
// 错误处理
|
||||
func errorHandle(stabilityAIError *StabilityAIError) *types.OpenAIError {
|
||||
openaiError := &types.OpenAIError{
|
||||
Type: "stabilityAI_error",
|
||||
}
|
||||
|
||||
if stabilityAIError.Name != "" {
|
||||
openaiError.Message = stabilityAIError.String()
|
||||
openaiError.Code = stabilityAIError.Name
|
||||
} else {
|
||||
openaiError.Message = stabilityAIError.Message
|
||||
openaiError.Code = "stabilityAI_error"
|
||||
}
|
||||
|
||||
return openaiError
|
||||
}
|
||||
|
||||
func (p *StabilityAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
|
||||
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
|
||||
|
||||
return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
func (p *StabilityAIProvider) GetRequestHeaders() (headers map[string]string) {
|
||||
headers = make(map[string]string)
|
||||
p.CommonRequestHeaders(headers)
|
||||
headers["Authorization"] = "Bearer " + p.Channel.Key
|
||||
|
||||
return headers
|
||||
}
|
||||
87
providers/stabilityAI/image_generations.go
Normal file
87
providers/stabilityAI/image_generations.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package stabilityAI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"one-api/common"
|
||||
"one-api/common/storage"
|
||||
"one-api/types"
|
||||
"time"
|
||||
)
|
||||
|
||||
func convertModelName(modelName string) string {
|
||||
if modelName == "stable-image-core" {
|
||||
return "core"
|
||||
}
|
||||
|
||||
return "sd3"
|
||||
}
|
||||
|
||||
func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
|
||||
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
// 获取请求地址
|
||||
fullRequestURL := p.GetFullRequestURL(url, convertModelName(request.Model))
|
||||
if fullRequestURL == "" {
|
||||
return nil, common.ErrorWrapper(nil, "invalid_stabilityAI_config", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
// 获取请求头
|
||||
headers := p.GetRequestHeaders()
|
||||
headers["Accept"] = "application/json; type=image/png"
|
||||
|
||||
var formBody bytes.Buffer
|
||||
builder := p.Requester.CreateFormBuilder(&formBody)
|
||||
builder.WriteField("prompt", request.Prompt)
|
||||
builder.WriteField("output_format", "png")
|
||||
if request.Model != "stable-image-core" {
|
||||
builder.WriteField("model", request.Model)
|
||||
}
|
||||
builder.Close()
|
||||
|
||||
req, err := p.Requester.NewRequest(
|
||||
http.MethodPost,
|
||||
fullRequestURL,
|
||||
p.Requester.WithBody(&formBody),
|
||||
p.Requester.WithHeader(headers),
|
||||
p.Requester.WithContentType(builder.FormDataContentType()))
|
||||
req.ContentLength = int64(formBody.Len())
|
||||
|
||||
if err != nil {
|
||||
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
|
||||
}
|
||||
|
||||
stabilityAIResponse := &generateResponse{}
|
||||
|
||||
// 发送请求
|
||||
_, errWithCode = p.Requester.SendRequest(req, stabilityAIResponse, false)
|
||||
if errWithCode != nil {
|
||||
return nil, errWithCode
|
||||
}
|
||||
|
||||
openaiResponse := &types.ImageResponse{
|
||||
Created: time.Now().Unix(),
|
||||
}
|
||||
|
||||
imgUrl := ""
|
||||
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
|
||||
body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image)
|
||||
if err == nil {
|
||||
imgUrl = storage.Upload(body, common.GetUUID()+".png")
|
||||
}
|
||||
}
|
||||
|
||||
if imgUrl == "" {
|
||||
openaiResponse.Data = []types.ImageResponseDataInner{{B64JSON: stabilityAIResponse.Image}}
|
||||
} else {
|
||||
openaiResponse.Data = []types.ImageResponseDataInner{{URL: imgUrl}}
|
||||
}
|
||||
|
||||
p.Usage.PromptTokens = 1000
|
||||
|
||||
return openaiResponse, nil
|
||||
}
|
||||
20
providers/stabilityAI/type.go
Normal file
20
providers/stabilityAI/type.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package stabilityAI
|
||||
|
||||
import "strings"
|
||||
|
||||
type StabilityAIError struct {
|
||||
Name string `json:"name,omitempty"`
|
||||
Errors []string `json:"errors,omitempty"`
|
||||
Success bool `json:"success,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
func (e StabilityAIError) String() string {
|
||||
return strings.Join(e.Errors, ", ")
|
||||
}
|
||||
|
||||
type generateResponse struct {
|
||||
Image string `json:"image"`
|
||||
FinishReason string `json:"finish_reason,omitempty"`
|
||||
Seed int `json:"seed,omitempty"`
|
||||
}
|
||||
Reference in New Issue
Block a user