feat: add Stability AI

This commit is contained in:
Martial BE
2024-04-18 18:53:49 +08:00
parent b20659dfcc
commit 303fe3360b
10 changed files with 216 additions and 0 deletions

View File

@@ -21,6 +21,7 @@ import (
"one-api/providers/mistral"
"one-api/providers/openai"
"one-api/providers/palm"
"one-api/providers/stabilityAI"
"one-api/providers/tencent"
"one-api/providers/xunfei"
"one-api/providers/zhipu"
@@ -58,6 +59,7 @@ func init() {
providerFactories[common.ChannelTypeMidjourney] = midjourney.MidjourneyProviderFactory{}
providerFactories[common.ChannelTypeCloudflareAI] = cloudflareAI.CloudflareAIProviderFactory{}
providerFactories[common.ChannelTypeCohere] = cohere.CohereProviderFactory{}
providerFactories[common.ChannelTypeStabilityAI] = stabilityAI.StabilityAIProviderFactory{}
}

View File

@@ -0,0 +1,79 @@
package stabilityAI
import (
"encoding/json"
"fmt"
"net/http"
"one-api/common/requester"
"one-api/model"
"one-api/providers/base"
"one-api/types"
"strings"
)
type StabilityAIProviderFactory struct{}
// 创建 StabilityAIProvider
func (f StabilityAIProviderFactory) Create(channel *model.Channel) base.ProviderInterface {
return &StabilityAIProvider{
BaseProvider: base.BaseProvider{
Config: getConfig(),
Channel: channel,
Requester: requester.NewHTTPRequester(*channel.Proxy, requestErrorHandle),
},
}
}
type StabilityAIProvider struct {
base.BaseProvider
}
func getConfig() base.ProviderConfig {
return base.ProviderConfig{
BaseURL: "https://api.stability.ai/v2beta",
ImagesGenerations: "/stable-image/generate",
}
}
// 请求错误处理
func requestErrorHandle(resp *http.Response) *types.OpenAIError {
stabilityAIError := &StabilityAIError{}
err := json.NewDecoder(resp.Body).Decode(stabilityAIError)
if err != nil {
return nil
}
return errorHandle(stabilityAIError)
}
// 错误处理
func errorHandle(stabilityAIError *StabilityAIError) *types.OpenAIError {
openaiError := &types.OpenAIError{
Type: "stabilityAI_error",
}
if stabilityAIError.Name != "" {
openaiError.Message = stabilityAIError.String()
openaiError.Code = stabilityAIError.Name
} else {
openaiError.Message = stabilityAIError.Message
openaiError.Code = "stabilityAI_error"
}
return openaiError
}
func (p *StabilityAIProvider) GetFullRequestURL(requestURL string, modelName string) string {
baseURL := strings.TrimSuffix(p.GetBaseURL(), "/")
return fmt.Sprintf("%s%s/%s", baseURL, requestURL, modelName)
}
// 获取请求头
func (p *StabilityAIProvider) GetRequestHeaders() (headers map[string]string) {
headers = make(map[string]string)
p.CommonRequestHeaders(headers)
headers["Authorization"] = "Bearer " + p.Channel.Key
return headers
}

View File

@@ -0,0 +1,87 @@
package stabilityAI
import (
"bytes"
"encoding/base64"
"net/http"
"one-api/common"
"one-api/common/storage"
"one-api/types"
"time"
)
func convertModelName(modelName string) string {
if modelName == "stable-image-core" {
return "core"
}
return "sd3"
}
func (p *StabilityAIProvider) CreateImageGenerations(request *types.ImageRequest) (*types.ImageResponse, *types.OpenAIErrorWithStatusCode) {
url, errWithCode := p.GetSupportedAPIUri(common.RelayModeImagesGenerations)
if errWithCode != nil {
return nil, errWithCode
}
// 获取请求地址
fullRequestURL := p.GetFullRequestURL(url, convertModelName(request.Model))
if fullRequestURL == "" {
return nil, common.ErrorWrapper(nil, "invalid_stabilityAI_config", http.StatusInternalServerError)
}
// 获取请求头
headers := p.GetRequestHeaders()
headers["Accept"] = "application/json; type=image/png"
var formBody bytes.Buffer
builder := p.Requester.CreateFormBuilder(&formBody)
builder.WriteField("prompt", request.Prompt)
builder.WriteField("output_format", "png")
if request.Model != "stable-image-core" {
builder.WriteField("model", request.Model)
}
builder.Close()
req, err := p.Requester.NewRequest(
http.MethodPost,
fullRequestURL,
p.Requester.WithBody(&formBody),
p.Requester.WithHeader(headers),
p.Requester.WithContentType(builder.FormDataContentType()))
req.ContentLength = int64(formBody.Len())
if err != nil {
return nil, common.ErrorWrapper(err, "new_request_failed", http.StatusInternalServerError)
}
stabilityAIResponse := &generateResponse{}
// 发送请求
_, errWithCode = p.Requester.SendRequest(req, stabilityAIResponse, false)
if errWithCode != nil {
return nil, errWithCode
}
openaiResponse := &types.ImageResponse{
Created: time.Now().Unix(),
}
imgUrl := ""
if request.ResponseFormat == "" || request.ResponseFormat == "url" {
body, err := base64.StdEncoding.DecodeString(stabilityAIResponse.Image)
if err == nil {
imgUrl = storage.Upload(body, common.GetUUID()+".png")
}
}
if imgUrl == "" {
openaiResponse.Data = []types.ImageResponseDataInner{{B64JSON: stabilityAIResponse.Image}}
} else {
openaiResponse.Data = []types.ImageResponseDataInner{{URL: imgUrl}}
}
p.Usage.PromptTokens = 1000
return openaiResponse, nil
}

View File

@@ -0,0 +1,20 @@
package stabilityAI
import "strings"
type StabilityAIError struct {
Name string `json:"name,omitempty"`
Errors []string `json:"errors,omitempty"`
Success bool `json:"success,omitempty"`
Message string `json:"message,omitempty"`
}
func (e StabilityAIError) String() string {
return strings.Join(e.Errors, ", ")
}
type generateResponse struct {
Image string `json:"image"`
FinishReason string `json:"finish_reason,omitempty"`
Seed int `json:"seed,omitempty"`
}