Compare commits


2 Commits

Author    SHA1        Message                                 Date
JustSong  b33616df44  feat: support groq now (close #1087)    2024-03-10 14:09:44 +08:00
JustSong  cf16f44970  feat: load channel models from server   2024-03-09 02:28:23 +08:00
13 changed files with 197 additions and 180 deletions

View File

@@ -38,35 +38,38 @@ const (
 )

 const (
-    ChannelTypeUnknown = 0
-    ChannelTypeOpenAI = 1
-    ChannelTypeAPI2D = 2
-    ChannelTypeAzure = 3
-    ChannelTypeCloseAI = 4
-    ChannelTypeOpenAISB = 5
-    ChannelTypeOpenAIMax = 6
-    ChannelTypeOhMyGPT = 7
-    ChannelTypeCustom = 8
-    ChannelTypeAILS = 9
-    ChannelTypeAIProxy = 10
-    ChannelTypePaLM = 11
-    ChannelTypeAPI2GPT = 12
-    ChannelTypeAIGC2D = 13
-    ChannelTypeAnthropic = 14
-    ChannelTypeBaidu = 15
-    ChannelTypeZhipu = 16
-    ChannelTypeAli = 17
-    ChannelTypeXunfei = 18
-    ChannelType360 = 19
-    ChannelTypeOpenRouter = 20
-    ChannelTypeAIProxyLibrary = 21
-    ChannelTypeFastGPT = 22
-    ChannelTypeTencent = 23
-    ChannelTypeGemini = 24
-    ChannelTypeMoonshot = 25
-    ChannelTypeBaichuan = 26
-    ChannelTypeMinimax = 27
-    ChannelTypeMistral = 28
+    ChannelTypeUnknown = iota
+    ChannelTypeOpenAI
+    ChannelTypeAPI2D
+    ChannelTypeAzure
+    ChannelTypeCloseAI
+    ChannelTypeOpenAISB
+    ChannelTypeOpenAIMax
+    ChannelTypeOhMyGPT
+    ChannelTypeCustom
+    ChannelTypeAILS
+    ChannelTypeAIProxy
+    ChannelTypePaLM
+    ChannelTypeAPI2GPT
+    ChannelTypeAIGC2D
+    ChannelTypeAnthropic
+    ChannelTypeBaidu
+    ChannelTypeZhipu
+    ChannelTypeAli
+    ChannelTypeXunfei
+    ChannelType360
+    ChannelTypeOpenRouter
+    ChannelTypeAIProxyLibrary
+    ChannelTypeFastGPT
+    ChannelTypeTencent
+    ChannelTypeGemini
+    ChannelTypeMoonshot
+    ChannelTypeBaichuan
+    ChannelTypeMinimax
+    ChannelTypeMistral
+    ChannelTypeGroq
+    ChannelTypeDummy
 )

 var ChannelBaseURLs = []string{
@@ -99,6 +102,7 @@ var ChannelBaseURLs = []string{
     "https://api.baichuan-ai.com", // 26
     "https://api.minimax.chat",    // 27
     "https://api.mistral.ai",      // 28
+    "https://api.groq.com/openai", // 29
 }

 const (

View File

@@ -125,6 +125,11 @@ var ModelRatio = map[string]float64{
     "mistral-medium-latest": 2.7 / 1000 * USD,
     "mistral-large-latest": 8.0 / 1000 * USD,
     "mistral-embed": 0.1 / 1000 * USD,
+    // https://wow.groq.com/
+    "llama2-70b-4096": 0.7 / 1000 * USD,
+    "llama2-7b-2048": 0.1 / 1000 * USD,
+    "mixtral-8x7b-32768": 0.27 / 1000 * USD,
+    "gemma-7b-it": 0.1 / 1000 * USD,
 }

 var CompletionRatio = map[string]float64{}
@@ -209,7 +214,7 @@ func GetCompletionRatio(name string) float64 {
                 return 2
             }
         }
-        return 1.333333
+        return 4.0 / 3.0
     }
     if strings.HasPrefix(name, "gpt-4") {
         if strings.HasSuffix(name, "preview") {
@@ -226,5 +231,9 @@ func GetCompletionRatio(name string) float64 {
     if strings.HasPrefix(name, "mistral-") {
         return 3
     }
+    switch name {
+    case "llama2-70b-4096":
+        return 0.8 / 0.7
+    }
     return 1
 }
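For reference, under one-api's ratio convention (a ratio of 1 corresponds to $0.002 per 1K tokens, per the USD constant), 0.7 / 1000 * USD prices llama2-70b-4096 prompts at $0.70 per million tokens, and the 0.8 / 0.7 completion ratio works out to $0.80 per million completion tokens; the other Groq entries have no completion ratio, so their completions bill at the same rate as prompts.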

View File

@@ -3,14 +3,13 @@ package controller
 import (
     "fmt"
     "github.com/gin-gonic/gin"
-    "github.com/songquanpeng/one-api/relay/channel/ai360"
-    "github.com/songquanpeng/one-api/relay/channel/baichuan"
-    "github.com/songquanpeng/one-api/relay/channel/minimax"
-    "github.com/songquanpeng/one-api/relay/channel/mistral"
-    "github.com/songquanpeng/one-api/relay/channel/moonshot"
+    "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/relay/channel/openai"
     "github.com/songquanpeng/one-api/relay/constant"
     "github.com/songquanpeng/one-api/relay/helper"
     relaymodel "github.com/songquanpeng/one-api/relay/model"
+    "github.com/songquanpeng/one-api/relay/util"
+    "net/http"
 )

 // https://platform.openai.com/docs/api-reference/models/list
@@ -42,6 +41,7 @@ type OpenAIModels struct {
 var openAIModels []OpenAIModels
 var openAIModelsMap map[string]OpenAIModels
+var channelId2Models map[int][]string

 func init() {
     var permission []OpenAIModelPermission
@@ -79,65 +79,44 @@ func init() {
             })
         }
     }
-    for _, modelName := range ai360.ModelList {
-        openAIModels = append(openAIModels, OpenAIModels{
-            Id:         modelName,
-            Object:     "model",
-            Created:    1626777600,
-            OwnedBy:    "360",
-            Permission: permission,
-            Root:       modelName,
-            Parent:     nil,
-        })
-    }
-    for _, modelName := range moonshot.ModelList {
-        openAIModels = append(openAIModels, OpenAIModels{
-            Id:         modelName,
-            Object:     "model",
-            Created:    1626777600,
-            OwnedBy:    "moonshot",
-            Permission: permission,
-            Root:       modelName,
-            Parent:     nil,
-        })
-    }
-    for _, modelName := range baichuan.ModelList {
-        openAIModels = append(openAIModels, OpenAIModels{
-            Id:         modelName,
-            Object:     "model",
-            Created:    1626777600,
-            OwnedBy:    "baichuan",
-            Permission: permission,
-            Root:       modelName,
-            Parent:     nil,
-        })
-    }
-    for _, modelName := range minimax.ModelList {
-        openAIModels = append(openAIModels, OpenAIModels{
-            Id:         modelName,
-            Object:     "model",
-            Created:    1626777600,
-            OwnedBy:    "minimax",
-            Permission: permission,
-            Root:       modelName,
-            Parent:     nil,
-        })
-    }
-    for _, modelName := range mistral.ModelList {
-        openAIModels = append(openAIModels, OpenAIModels{
-            Id:         modelName,
-            Object:     "model",
-            Created:    1626777600,
-            OwnedBy:    "mistralai",
-            Permission: permission,
-            Root:       modelName,
-            Parent:     nil,
-        })
-    }
+    for _, channelType := range openai.CompatibleChannels {
+        if channelType == common.ChannelTypeAzure {
+            continue
+        }
+        channelName, channelModelList := openai.GetCompatibleChannelMeta(channelType)
+        for _, modelName := range channelModelList {
+            openAIModels = append(openAIModels, OpenAIModels{
+                Id:         modelName,
+                Object:     "model",
+                Created:    1626777600,
+                OwnedBy:    channelName,
+                Permission: permission,
+                Root:       modelName,
+                Parent:     nil,
+            })
+        }
+    }
     openAIModelsMap = make(map[string]OpenAIModels)
     for _, model := range openAIModels {
         openAIModelsMap[model.Id] = model
     }
+    channelId2Models = make(map[int][]string)
+    for i := 1; i < common.ChannelTypeDummy; i++ {
+        adaptor := helper.GetAdaptor(constant.ChannelType2APIType(i))
+        meta := &util.RelayMeta{
+            ChannelType: i,
+        }
+        adaptor.Init(meta)
+        channelId2Models[i] = adaptor.GetModelList()
+    }
+}
+
+func DashboardListModels(c *gin.Context) {
+    c.JSON(http.StatusOK, gin.H{
+        "success": true,
+        "message": "",
+        "data":    channelId2Models,
+    })
 }

 func ListModels(c *gin.Context) {

View File

@@ -0,0 +1,10 @@
+package groq
+
+// https://console.groq.com/docs/models
+
+var ModelList = []string{
+    "gemma-7b-it",
+    "llama2-7b-2048",
+    "llama2-70b-4096",
+    "mixtral-8x7b-32768",
+}
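The base URL registered above ("https://api.groq.com/openai") is OpenAI-compatible, which is why the Groq channel can reuse the openai adaptor. Below is a minimal Go sketch of calling Groq directly, assuming the standard /v1/chat/completions path and a placeholder API key; neither is part of this diff.

package main

import (
    "bytes"
    "fmt"
    "io"
    "net/http"
)

func main() {
    // Placeholder key; real keys are issued from the Groq console.
    apiKey := "YOUR_GROQ_API_KEY"
    payload := []byte(`{"model": "mixtral-8x7b-32768", "messages": [{"role": "user", "content": "Hello"}]}`)

    req, err := http.NewRequest("POST", "https://api.groq.com/openai/v1/chat/completions", bytes.NewReader(payload))
    if err != nil {
        panic(err)
    }
    req.Header.Set("Content-Type", "application/json")
    req.Header.Set("Authorization", "Bearer "+apiKey)

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        panic(err)
    }
    defer resp.Body.Close()

    body, _ := io.ReadAll(resp.Body)
    fmt.Println(resp.Status)
    fmt.Println(string(body)) // OpenAI-style chat completion response
}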

View File

@@ -6,11 +6,7 @@ import (
     "github.com/gin-gonic/gin"
     "github.com/songquanpeng/one-api/common"
     "github.com/songquanpeng/one-api/relay/channel"
-    "github.com/songquanpeng/one-api/relay/channel/ai360"
-    "github.com/songquanpeng/one-api/relay/channel/baichuan"
     "github.com/songquanpeng/one-api/relay/channel/minimax"
-    "github.com/songquanpeng/one-api/relay/channel/mistral"
-    "github.com/songquanpeng/one-api/relay/channel/moonshot"
     "github.com/songquanpeng/one-api/relay/model"
     "github.com/songquanpeng/one-api/relay/util"
     "io"
@@ -86,37 +82,11 @@ func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, meta *util.Rel
 }

 func (a *Adaptor) GetModelList() []string {
-    switch a.ChannelType {
-    case common.ChannelType360:
-        return ai360.ModelList
-    case common.ChannelTypeMoonshot:
-        return moonshot.ModelList
-    case common.ChannelTypeBaichuan:
-        return baichuan.ModelList
-    case common.ChannelTypeMinimax:
-        return minimax.ModelList
-    case common.ChannelTypeMistral:
-        return mistral.ModelList
-    default:
-        return ModelList
-    }
+    _, modelList := GetCompatibleChannelMeta(a.ChannelType)
+    return modelList
 }

 func (a *Adaptor) GetChannelName() string {
-    switch a.ChannelType {
-    case common.ChannelTypeAzure:
-        return "azure"
-    case common.ChannelType360:
-        return "360"
-    case common.ChannelTypeMoonshot:
-        return "moonshot"
-    case common.ChannelTypeBaichuan:
-        return "baichuan"
-    case common.ChannelTypeMinimax:
-        return "minimax"
-    case common.ChannelTypeMistral:
-        return "mistralai"
-    default:
-        return "openai"
-    }
+    channelName, _ := GetCompatibleChannelMeta(a.ChannelType)
+    return channelName
 }

View File

@@ -0,0 +1,42 @@
+package openai
+
+import (
+    "github.com/songquanpeng/one-api/common"
+    "github.com/songquanpeng/one-api/relay/channel/ai360"
+    "github.com/songquanpeng/one-api/relay/channel/baichuan"
+    "github.com/songquanpeng/one-api/relay/channel/groq"
+    "github.com/songquanpeng/one-api/relay/channel/minimax"
+    "github.com/songquanpeng/one-api/relay/channel/mistral"
+    "github.com/songquanpeng/one-api/relay/channel/moonshot"
+)
+
+var CompatibleChannels = []int{
+    common.ChannelTypeAzure,
+    common.ChannelType360,
+    common.ChannelTypeMoonshot,
+    common.ChannelTypeBaichuan,
+    common.ChannelTypeMinimax,
+    common.ChannelTypeMistral,
+    common.ChannelTypeGroq,
+}
+
+func GetCompatibleChannelMeta(channelType int) (string, []string) {
+    switch channelType {
+    case common.ChannelTypeAzure:
+        return "azure", ModelList
+    case common.ChannelType360:
+        return "360", ai360.ModelList
+    case common.ChannelTypeMoonshot:
+        return "moonshot", moonshot.ModelList
+    case common.ChannelTypeBaichuan:
+        return "baichuan", baichuan.ModelList
+    case common.ChannelTypeMinimax:
+        return "minimax", minimax.ModelList
+    case common.ChannelTypeMistral:
+        return "mistralai", mistral.ModelList
+    case common.ChannelTypeGroq:
+        return "groq", groq.ModelList
+    default:
+        return "openai", ModelList
+    }
+}

View File

@@ -14,6 +14,7 @@ func SetApiRouter(router *gin.Engine) {
     apiRouter.Use(middleware.GlobalAPIRateLimit())
     {
         apiRouter.GET("/status", controller.GetStatus)
+        apiRouter.GET("/models", middleware.UserAuth(), controller.DashboardListModels)
         apiRouter.GET("/notice", controller.GetNotice)
         apiRouter.GET("/about", controller.GetAbout)
         apiRouter.GET("/home_page_content", controller.GetHomePageContent)

View File

@@ -15,7 +15,7 @@ export const CHANNEL_OPTIONS = {
     key: 3,
     text: 'Azure OpenAI',
     value: 3,
-    color: 'orange'
+    color: 'secondary'
   },
   11: {
     key: 11,
@@ -89,6 +89,12 @@ export const CHANNEL_OPTIONS = {
     value: 27,
     color: 'default'
   },
+  29: {
+    key: 29,
+    text: 'Groq',
+    value: 29,
+    color: 'default'
+  },
   8: {
     key: 8,
     text: '自定义渠道',

View File

@@ -163,6 +163,9 @@ const typeConfig = {
     },
     modelGroup: "minimax",
   },
+  29: {
+    modelGroup: "groq",
+  },
 };

 export { defaultConfig, typeConfig };

View File

@@ -1,7 +1,16 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Form, Input, Label, Message, Pagination, Popup, Table } from 'semantic-ui-react';
 import { Link } from 'react-router-dom';
-import { API, setPromptShown, shouldShowPrompt, showError, showInfo, showSuccess, timestamp2string } from '../helpers';
+import {
+  API,
+  loadChannelModels,
+  setPromptShown,
+  shouldShowPrompt,
+  showError,
+  showInfo,
+  showSuccess,
+  timestamp2string
+} from '../helpers';
 import { CHANNEL_OPTIONS, ITEMS_PER_PAGE } from '../constants';
 import { renderGroup, renderNumber } from '../helpers/render';
@@ -95,6 +104,7 @@ const ChannelsTable = () => {
       .catch((reason) => {
         showError(reason);
       });
+    loadChannelModels().then();
   }, []);

   const manageChannel = async (id, action, idx, value) => {

View File

@@ -14,6 +14,7 @@ export const CHANNEL_OPTIONS = [
   { key: 23, text: '腾讯混元', value: 23, color: 'teal' },
   { key: 26, text: '百川大模型', value: 26, color: 'orange' },
   { key: 27, text: 'MiniMax', value: 27, color: 'red' },
+  { key: 29, text: 'Groq', value: 29, color: 'orange' },
   { key: 8, text: '自定义渠道', value: 8, color: 'pink' },
   { key: 22, text: '知识库FastGPT', value: 22, color: 'blue' },
   { key: 21, text: '知识库AI Proxy', value: 21, color: 'purple' },

View File

@@ -1,11 +1,13 @@
 import { toast } from 'react-toastify';
 import { toastConstants } from '../constants';
 import React from 'react';
+import { API } from './api';

 const HTMLToastContent = ({ htmlContent }) => {
   return <div dangerouslySetInnerHTML={{ __html: htmlContent }} />;
 };
 export default HTMLToastContent;

 export function isAdmin() {
   let user = localStorage.getItem('user');
   if (!user) return false;
@@ -29,7 +31,7 @@ export function getSystemName() {
 export function getLogo() {
   let logo = localStorage.getItem('logo');
   if (!logo) return '/logo.png';
-  return logo
+  return logo;
 }

 export function getFooterHTML() {
@@ -197,3 +199,29 @@ export function shouldShowPrompt(id) {
 export function setPromptShown(id) {
   localStorage.setItem(`prompt-${id}`, 'true');
 }
+
+let channelModels = undefined;
+export async function loadChannelModels() {
+  const res = await API.get('/api/models');
+  const { success, data } = res.data;
+  if (!success) {
+    return;
+  }
+  channelModels = data;
+  localStorage.setItem('channel_models', JSON.stringify(data));
+}
+
+export function getChannelModels(type) {
+  if (channelModels !== undefined && type in channelModels) {
+    return channelModels[type];
+  }
+  let models = localStorage.getItem('channel_models');
+  if (!models) {
+    return [];
+  }
+  channelModels = JSON.parse(models);
+  if (type in channelModels) {
+    return channelModels[type];
+  }
+  return [];
+}

View File

@@ -1,7 +1,7 @@
 import React, { useEffect, useState } from 'react';
 import { Button, Form, Header, Input, Message, Segment } from 'semantic-ui-react';
 import { useNavigate, useParams } from 'react-router-dom';
-import { API, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
+import { API, copy, getChannelModels, showError, showInfo, showSuccess, verifyJSON } from '../../helpers';
 import { CHANNEL_OPTIONS } from '../../constants';

 const MODEL_MAPPING_EXAMPLE = {
@@ -56,60 +56,12 @@ const EditChannel = () => {
   const [customModel, setCustomModel] = useState('');
   const handleInputChange = (e, { name, value }) => {
     setInputs((inputs) => ({ ...inputs, [name]: value }));
-    if (name === 'type' && inputs.models.length === 0) {
-      let localModels = [];
-      switch (value) {
-        case 14:
-          localModels = ['claude-instant-1', 'claude-2', 'claude-2.0', 'claude-2.1'];
-          break;
-        case 11:
-          localModels = ['PaLM-2'];
-          break;
-        case 15:
-          localModels = ['ERNIE-Bot', 'ERNIE-Bot-turbo', 'ERNIE-Bot-4', 'Embedding-V1'];
-          break;
-        case 17:
-          localModels = ['qwen-turbo', 'qwen-plus', 'qwen-max', 'qwen-max-longcontext', 'text-embedding-v1'];
-          let withInternetVersion = [];
-          for (let i = 0; i < localModels.length; i++) {
-            if (localModels[i].startsWith('qwen-')) {
-              withInternetVersion.push(localModels[i] + '-internet');
-            }
-          }
-          localModels = [...localModels, ...withInternetVersion];
-          break;
-        case 16:
-          localModels = ["glm-4", "glm-4v", "glm-3-turbo",'chatglm_turbo', 'chatglm_pro', 'chatglm_std', 'chatglm_lite'];
-          break;
-        case 18:
-          localModels = [
-            'SparkDesk',
-            'SparkDesk-v1.1',
-            'SparkDesk-v2.1',
-            'SparkDesk-v3.1',
-            'SparkDesk-v3.5'
-          ];
-          break;
-        case 19:
-          localModels = ['360GPT_S2_V9', 'embedding-bert-512-v1', 'embedding_s1_v1', 'semantic_similarity_s1_v1'];
-          break;
-        case 23:
-          localModels = ['hunyuan'];
-          break;
-        case 24:
-          localModels = ['gemini-pro', 'gemini-pro-vision'];
-          break;
-        case 25:
-          localModels = ['moonshot-v1-8k', 'moonshot-v1-32k', 'moonshot-v1-128k'];
-          break;
-        case 26:
-          localModels = ['Baichuan2-Turbo', 'Baichuan2-Turbo-192k', 'Baichuan-Text-Embedding'];
-          break;
-        case 27:
-          localModels = ['abab5.5s-chat', 'abab5.5-chat', 'abab6-chat'];
-          break;
-      }
-      setInputs((inputs) => ({ ...inputs, models: localModels }));
-    }
+    if (name === 'type') {
+      let localModels = getChannelModels(value);
+      if (inputs.models.length === 0) {
+        setInputs((inputs) => ({ ...inputs, models: localModels }));
+      }
+      setBasicModels(localModels);
+    }
   };
@@ -390,6 +342,8 @@ const EditChannel = () => {
             required
             fluid
             multiple
+            search
+            onLabelClick={(e, { value }) => {copy(value).then()}}
             selection
             onChange={handleInputChange}
             value={inputs.models}