merge v4.2.6

整合 v4.2.6 的后端中间件与服务层重构、前端样式体系迁移和管理端/移动端功能更新,统一清理历史冲突并完成版本升级。

Made-with: Cursor
This commit is contained in:
RockYang
2026-04-08 15:08:34 +08:00
390 changed files with 35519 additions and 25073 deletions

View File

@@ -8,35 +8,38 @@ package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core/types"
"github.com/imroc/req/v3"
"time"
"github.com/imroc/req/v3"
)
type CaptchaService struct {
config types.ApiConfig
config types.CaptchaConfig
client *req.Client
}
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
func NewCaptchaService(captchaConfig types.CaptchaConfig) *CaptchaService {
return &CaptchaService{
config: config,
config: captchaConfig,
client: req.C().SetTimeout(10 * time.Second),
}
}
func (s *CaptchaService) UpdateConfig(config types.CaptchaConfig) {
s.config = config
}
func (s *CaptchaService) GetConfig() types.CaptchaConfig {
return s.config
}
func (s *CaptchaService) Get() (interface{}, error) {
if s.config.Token == "" {
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/get", s.config.ApiURL)
url := fmt.Sprintf("%s/api/captcha/get", types.GeekAPIURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err)
@@ -49,12 +52,11 @@ func (s *CaptchaService) Get() (interface{}, error) {
return res.Data, nil
}
func (s *CaptchaService) Check(data interface{}) bool {
url := fmt.Sprintf("%s/api/captcha/check", s.config.ApiURL)
func (s *CaptchaService) Check(data any) bool {
url := fmt.Sprintf("%s/api/captcha/check", types.GeekAPIURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
@@ -68,16 +70,11 @@ func (s *CaptchaService) Check(data interface{}) bool {
return true
}
func (s *CaptchaService) SlideGet() (interface{}, error) {
if s.config.Token == "" {
return nil, errors.New("无效的 API Token")
}
url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
func (s *CaptchaService) SlideGet() (any, error) {
url := fmt.Sprintf("%s/api/captcha/slide/get", types.GeekAPIURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetSuccessResult(&res).Get(url)
if err != nil || r.IsErrorState() {
return nil, fmt.Errorf("请求 API 失败:%v", err)
@@ -90,12 +87,11 @@ func (s *CaptchaService) SlideGet() (interface{}, error) {
return res.Data, nil
}
func (s *CaptchaService) SlideCheck(data interface{}) bool {
url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
func (s *CaptchaService) SlideCheck(data any) bool {
url := fmt.Sprintf("%s/api/captcha/slide/check", types.GeekAPIURL)
var res types.BizVo
r, err := s.client.R().
SetHeader("AppId", s.config.AppId).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.ApiKey)).
SetBodyJsonMarshal(data).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {

View File

@@ -1,333 +0,0 @@
package crawler
import (
"context"
"errors"
"fmt"
"geekai/logger"
"net/url"
"strings"
"time"
"github.com/go-rod/rod"
"github.com/go-rod/rod/lib/launcher"
"github.com/go-rod/rod/lib/proto"
)
// Service 网络爬虫服务
type Service struct {
browser *rod.Browser
}
// NewService 创建一个新的爬虫服务
func NewService() (*Service, error) {
// 启动浏览器
path, _ := launcher.LookPath()
u := launcher.New().Bin(path).
Headless(true). // 无头模式
Set("disable-web-security", ""). // 禁用网络安全限制
Set("disable-gpu", ""). // 禁用 GPU 加速
Set("no-sandbox", ""). // 禁用沙箱模式
Set("disable-setuid-sandbox", ""). // 禁用 setuid 沙箱
MustLaunch()
browser := rod.New().ControlURL(u).MustConnect()
return &Service{
browser: browser,
}, nil
}
// SearchResult 搜索结果
type SearchResult struct {
Title string `json:"title"` // 标题
URL string `json:"url"` // 链接
Content string `json:"content"` // 内容摘要
}
// WebSearch 网络搜索
func (s *Service) WebSearch(keyword string, maxPages int) ([]SearchResult, error) {
if keyword == "" {
return nil, errors.New("搜索关键词不能为空")
}
if maxPages <= 0 {
maxPages = 1
}
if maxPages > 10 {
maxPages = 10 // 最多搜索 10 页
}
results := make([]SearchResult, 0)
// 使用百度搜索
searchURL := fmt.Sprintf("https://www.baidu.com/s?wd=%s", url.QueryEscape(keyword))
// 设置页面超时
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
defer cancel()
// 创建页面
page := s.browser.MustPage()
defer page.MustClose()
// 设置视口大小
err := page.SetViewport(&proto.EmulationSetDeviceMetricsOverride{
Width: 1280,
Height: 800,
})
if err != nil {
return nil, fmt.Errorf("设置视口失败: %v", err)
}
// 导航到搜索页面
err = page.Context(ctx).Navigate(searchURL)
if err != nil {
return nil, fmt.Errorf("导航到搜索页面失败: %v", err)
}
// 等待搜索结果加载完成
err = page.WaitLoad()
if err != nil {
return nil, fmt.Errorf("等待页面加载完成失败: %v", err)
}
// 分析当前页面的搜索结果
for i := 0; i < maxPages; i++ {
if i > 0 {
// 点击下一页按钮
nextPage, err := page.Element("a.n")
if err != nil || nextPage == nil {
break // 没有下一页
}
err = nextPage.Click(proto.InputMouseButtonLeft, 1)
if err != nil {
break // 点击下一页失败
}
// 等待新页面加载
err = page.WaitLoad()
if err != nil {
break
}
}
// 提取搜索结果
resultElements, err := page.Elements(".result, .c-container")
if err != nil || resultElements == nil {
continue
}
for _, result := range resultElements {
// 获取标题
titleElement, err := result.Element("h3, .t")
if err != nil || titleElement == nil {
continue
}
title, err := titleElement.Text()
if err != nil {
continue
}
// 获取 URL
linkElement, err := titleElement.Element("a")
if err != nil || linkElement == nil {
continue
}
href, err := linkElement.Attribute("href")
if err != nil || href == nil {
continue
}
// 获取内容摘要 - 尝试多个可能的选择器
var contentElement *rod.Element
var content string
// 尝试多个可能的选择器来适应不同版本的百度搜索结果
selectors := []string{".content-right_8Zs40", ".c-abstract", ".content_LJ0WN", ".content"}
for _, selector := range selectors {
contentElement, err = result.Element(selector)
if err == nil && contentElement != nil {
content, _ = contentElement.Text()
if content != "" {
break
}
}
}
// 如果所有选择器都失败,尝试直接从结果块中提取文本
if content == "" {
// 获取结果元素的所有文本
fullText, err := result.Text()
if err == nil && fullText != "" {
// 简单处理:从全文中移除标题,剩下的可能是摘要
fullText = strings.Replace(fullText, title, "", 1)
// 清理文本
content = strings.TrimSpace(fullText)
// 限制内容长度
if len(content) > 200 {
content = content[:200] + "..."
}
}
}
// 添加到结果集
results = append(results, SearchResult{
Title: title,
URL: *href,
Content: content,
})
// 限制结果数量,每页最多 10 条
if len(results) >= 10*maxPages {
break
}
}
}
// 获取真实 URL百度搜索结果中的 URL 是短链接,需要跳转获取真实 URL
for i, result := range results {
realURL, err := s.getRedirectURL(result.URL)
if err == nil && realURL != "" {
results[i].URL = realURL
}
}
return results, nil
}
// 获取真实 URL
func (s *Service) getRedirectURL(shortURL string) (string, error) {
// 创建页面
page, err := s.browser.Page(proto.TargetCreateTarget{URL: ""})
if err != nil {
return shortURL, err // 返回原始URL
}
defer func() {
_ = page.Close()
}()
// 导航到短链接
err = page.Navigate(shortURL)
if err != nil {
return shortURL, err // 返回原始URL
}
// 等待重定向完成
time.Sleep(2 * time.Second)
// 获取当前 URL
info, err := page.Info()
if err != nil {
return shortURL, err // 返回原始URL
}
return info.URL, nil
}
// Close 关闭浏览器
func (s *Service) Close() error {
if s.browser != nil {
err := s.browser.Close()
s.browser = nil
return err
}
return nil
}
// SearchWeb 封装的搜索方法
func SearchWeb(keyword string, maxPages int) (string, error) {
// 添加panic恢复机制
defer func() {
if r := recover(); r != nil {
log := logger.GetLogger()
log.Errorf("爬虫服务崩溃: %v", r)
}
}()
service, err := NewService()
if err != nil {
return "", fmt.Errorf("创建爬虫服务失败: %v", err)
}
defer service.Close()
// 设置超时上下文
ctx, cancel := context.WithTimeout(context.Background(), 60*time.Second)
defer cancel()
// 使用goroutine和通道来处理超时
resultChan := make(chan []SearchResult, 1)
errChan := make(chan error, 1)
go func() {
results, err := service.WebSearch(keyword, maxPages)
if err != nil {
errChan <- err
return
}
resultChan <- results
}()
// 等待结果或超时
select {
case <-ctx.Done():
return "", fmt.Errorf("搜索超时: %v", ctx.Err())
case err := <-errChan:
return "", fmt.Errorf("搜索失败: %v", err)
case results := <-resultChan:
if len(results) == 0 {
return "未找到关于 \"" + keyword + "\" 的相关搜索结果", nil
}
// 格式化结果
var builder strings.Builder
builder.WriteString(fmt.Sprintf("为您找到关于 \"%s\" 的 %d 条搜索结果:\n\n", keyword, len(results)))
for i, result := range results {
// // 尝试打开链接获取实际内容
// page := service.browser.MustPage()
// defer page.MustClose()
// // 设置页面超时
// pageCtx, pageCancel := context.WithTimeout(context.Background(), 10*time.Second)
// defer pageCancel()
// // 导航到目标页面
// err := page.Context(pageCtx).Navigate(result.URL)
// if err == nil {
// // 等待页面加载
// _ = page.WaitLoad()
// // 获取页面标题
// title, err := page.Eval("() => document.title")
// if err == nil && title.Value.String() != "" {
// result.Title = title.Value.String()
// }
// // 获取页面主要内容
// if content, err := page.Element("body"); err == nil {
// if text, err := content.Text(); err == nil {
// // 清理并截取内容
// text = strings.TrimSpace(text)
// if len(text) > 200 {
// text = text[:200] + "..."
// }
// result.Prompt = text
// }
// }
// }
builder.WriteString(fmt.Sprintf("%d. **%s**\n", i+1, result.Title))
builder.WriteString(fmt.Sprintf(" 链接: %s\n", result.URL))
if result.Content != "" {
builder.WriteString(fmt.Sprintf(" 摘要: %s\n", result.Content))
}
builder.WriteString("\n")
}
return builder.String(), nil
}
}

View File

@@ -16,6 +16,7 @@ import (
"geekai/store"
"geekai/store/model"
"geekai/utils"
"strings"
"time"
"github.com/go-redis/redis/v8"
@@ -94,12 +95,14 @@ func (s *Service) Run() {
}
type imgReq struct {
Model string `json:"model"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
Style string `json:"style,omitempty"`
Model string `json:"model"`
Image []string `json:"image,omitempty"`
Prompt string `json:"prompt"`
N int `json:"n,omitempty"`
Size string `json:"size,omitempty"`
Quality string `json:"quality,omitempty"`
Style string `json:"style,omitempty"`
ResponseFormat string `json:"response_format,omitempty"`
}
type imgRes struct {
@@ -122,15 +125,6 @@ type ErrRes struct {
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
logger.Debugf("绘画参数:%+v", task)
prompt := task.Prompt
// translate prompt
if utils.HasChinese(prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, prompt), task.TranslateModelId)
if err == nil {
prompt = content
logger.Debugf("重写后提示词:%s", prompt)
}
}
var chatModel model.ChatModel
if task.ModelId > 0 {
@@ -160,12 +154,17 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
apiURL := fmt.Sprintf("%s/v1/images/generations", apiKey.ApiURL)
reqBody := imgReq{
Model: chatModel.Value,
Prompt: prompt,
Prompt: task.Prompt,
N: 1,
Size: task.Size,
Style: task.Style,
Quality: task.Quality,
}
// 图片编辑
if len(task.Image) > 0 {
reqBody.Prompt = fmt.Sprintf("%s, %s", strings.Join(task.Image, " "), task.Prompt)
}
logger.Infof("Channel:%s, API KEY:%s, BODY: %+v", apiURL, apiKey.Value, reqBody)
r, err := s.httpClient.R().SetHeader("Body-Type", "application/json").
SetHeader("Authorization", "Bearer "+apiKey.Value).
@@ -188,7 +187,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
var imgURL string
var data = map[string]interface{}{
"progress": 100,
"prompt": prompt,
"prompt": task.Prompt,
}
// 如果返回的是base64则需要上传到oss
if res.Data[0].B64Json != "" {
@@ -210,11 +209,7 @@ func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
var content string
if sync {
imgURL, err := s.downloadImage(task.Id, res.Data[0].Url)
if err != nil {
return "", fmt.Errorf("error with download image: %v", err)
}
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", prompt, imgURL)
content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片\n\n![](%s)\n", task.Prompt, imgURL)
}
return content, nil

View File

@@ -3,8 +3,10 @@ package jimeng
import (
"encoding/json"
"fmt"
"geekai/core/types"
"net/http"
"net/url"
"strings"
"github.com/volcengine/volc-sdk-golang/base"
"github.com/volcengine/volc-sdk-golang/service/visual"
@@ -13,14 +15,22 @@ import (
// Client 即梦API客户端
type Client struct {
visual *visual.Visual
config types.JimengConfig
}
// NewClient 创建即梦API客户端
func NewClient(accessKey, secretKey string) *Client {
func NewClient(sysConfig *types.SystemConfig) *Client {
client := &Client{}
client.UpdateConfig(sysConfig.Jimeng)
return client
}
func (c *Client) UpdateConfig(config types.JimengConfig) error {
// 使用官方SDK的visual实例
visualInstance := visual.NewInstance()
visualInstance.Client.SetAccessKey(accessKey)
visualInstance.Client.SetSecretKey(secretKey)
visualInstance.Client.SetAccessKey(config.AccessKey)
visualInstance.Client.SetSecretKey(config.SecretKey)
// 添加即梦AI专有的API配置
jimengApis := map[string]*base.ApiInfo{
@@ -55,9 +65,32 @@ func NewClient(accessKey, secretKey string) *Client {
visualInstance.Client.ApiInfoList[name] = info
}
return &Client{
visual: visualInstance,
c.config = config
c.visual = visualInstance
return c.testConnection()
}
// testConnection 测试即梦AI连接
func (c *Client) testConnection() error {
// 使用一个简单的查询任务来测试连接
testReq := &QueryTaskRequest{
ReqKey: "test_connection",
TaskId: "test_task_id_12345",
}
_, err := c.QueryTask(testReq)
// 即使任务不存在,只要不是认证错误就说明连接正常
if err != nil {
// 检查是否是认证错误
if strings.Contains(err.Error(), "InvalidAccessKey") {
return fmt.Errorf("认证失败请检查AccessKey和SecretKey是否正确")
}
// 其他错误(如任务不存在)说明连接正常
return nil
}
return nil
}
// SubmitTask 提交异步任务

View File

@@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"strconv"
"strings"
"time"
"gorm.io/gorm"
@@ -16,8 +15,6 @@ import (
"geekai/store/model"
"geekai/utils"
"geekai/core/types"
"github.com/go-redis/redis/v8"
)
@@ -36,17 +33,8 @@ type Service struct {
}
// NewService 创建即梦服务
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager) *Service {
func NewService(db *gorm.DB, redisCli *redis.Client, uploader *oss.UploaderManager, client *Client) *Service {
taskQueue := store.NewRedisQueue("JimengTaskQueue", redisCli)
// 从数据库加载配置
var config model.Config
db.Where("name = ?", "Jimeng").First(&config)
var jimengConfig types.JimengConfig
if config.Id > 0 {
_ = utils.JsonDecode(config.Value, &jimengConfig)
}
client := NewClient(jimengConfig.AccessKey, jimengConfig.SecretKey)
ctx, cancel := context.WithCancel(context.Background())
return &Service{
db: db,
@@ -378,7 +366,7 @@ func (s *Service) pollTaskStatus() {
for _, job := range jobs {
// 任务超时处理
if job.UpdatedAt.Before(time.Now().Add(-5 * time.Minute)) {
if job.UpdatedAt.Before(time.Now().Add(-10 * time.Minute)) {
s.handleTaskError(job.Id, "task timeout")
continue
}
@@ -391,7 +379,7 @@ func (s *Service) pollTaskStatus() {
})
if err != nil {
logger.Errorf("query jimeng task status failed: %v", err)
s.handleTaskError(job.Id, fmt.Sprintf("query task failed: %s", err.Error()))
continue
}
@@ -446,9 +434,7 @@ func (s *Service) pollTaskStatus() {
s.handleTaskError(job.Id, "task not found")
case model.JMTaskStatusExpired:
// 任务过期
s.handleTaskError(job.Id, "task expired")
continue
default:
logger.Warnf("unknown task status: %s", resp.Data.Status)
}
@@ -524,77 +510,3 @@ func (s *Service) GetJob(jobId uint) (*model.JimengJob, error) {
}
return &job, nil
}
// testConnection 测试即梦AI连接
func (s *Service) testConnection(accessKey, secretKey string) error {
testClient := NewClient(accessKey, secretKey)
// 使用一个简单的查询任务来测试连接
testReq := &QueryTaskRequest{
ReqKey: "test_connection",
TaskId: "test_task_id_12345",
}
_, err := testClient.QueryTask(testReq)
// 即使任务不存在,只要不是认证错误就说明连接正常
if err != nil {
// 检查是否是认证错误
if strings.Contains(err.Error(), "InvalidAccessKey") {
return fmt.Errorf("认证失败请检查AccessKey和SecretKey是否正确")
}
// 其他错误(如任务不存在)说明连接正常
return nil
}
return nil
}
// UpdateClientConfig 更新客户端配置
func (s *Service) UpdateClientConfig(accessKey, secretKey string) error {
// 创建新的客户端
newClient := NewClient(accessKey, secretKey)
// 测试新客户端是否可用
err := s.testConnection(accessKey, secretKey)
if err != nil {
return err
}
// 更新客户端
s.client = newClient
return nil
}
var defaultPower = types.JimengPower{
TextToImage: 20,
ImageToImage: 20,
ImageEdit: 20,
ImageEffects: 20,
TextToVideo: 300,
ImageToVideo: 300,
}
// GetConfig 获取即梦AI配置
func (s *Service) GetConfig() *types.JimengConfig {
var config model.Config
err := s.db.Where("name", "jimeng").First(&config).Error
if err != nil {
// 如果配置不存在,返回默认配置
return &types.JimengConfig{
AccessKey: "",
SecretKey: "",
Power: defaultPower,
}
}
var jimengConfig types.JimengConfig
err = utils.JsonDecode(config.Value, &jimengConfig)
if err != nil {
return &types.JimengConfig{
AccessKey: "",
SecretKey: "",
Power: defaultPower,
}
}
return &jimengConfig
}

View File

@@ -8,30 +8,37 @@ package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"fmt"
"geekai/core"
"geekai/core/types"
"geekai/store"
"geekai/store/model"
"geekai/utils"
"strings"
"time"
"github.com/imroc/req/v3"
"github.com/shirou/gopsutil/host"
"gorm.io/gorm"
)
type LicenseService struct {
config types.ApiConfig
levelDB *store.LevelDB
license *types.License
urlWhiteList []string
machineId string
db *gorm.DB
}
func NewLicenseService(server *core.AppServer, levelDB *store.LevelDB) *LicenseService {
var license types.License
func NewLicenseService(sysConfig *types.SystemConfig, db *gorm.DB) *LicenseService {
var machineId string
info, err := host.Info()
if err == nil {
machineId = info.HostID
}
logger.Infof("License: %+v", sysConfig.License)
return &LicenseService{
config: server.Config.ApiConfig,
levelDB: levelDB,
license: &license,
machineId: "",
license: &sysConfig.License,
machineId: machineId,
db: db,
}
}
@@ -46,15 +53,15 @@ type License struct {
}
// ActiveLicense 激活 License
func (s *LicenseService) ActiveLicense(license string, machineId string) error {
func (s *LicenseService) ActiveLicense(license string) error {
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/active")
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/active")
response, err := req.C().R().
SetBody(map[string]string{"license": license, "machine_id": machineId}).
SetBody(map[string]string{"license": license, "machine_id": s.machineId}).
SetSuccessResult(&res).Post(apiURL)
if err != nil {
return fmt.Errorf("发送激活请求失败: %v", err)
@@ -68,17 +75,24 @@ func (s *LicenseService) ActiveLicense(license string, machineId string) error {
return fmt.Errorf("激活失败:%v", res.Message)
}
if res.Data.ExpiredAt > 0 && res.Data.ExpiredAt < time.Now().Unix() {
return fmt.Errorf("License 已过期")
}
s.license = &types.License{
Key: license,
MachineId: machineId,
MachineId: s.machineId,
Configs: res.Data.Configs,
ExpiredAt: res.Data.ExpiredAt,
IsActive: true,
}
err = s.levelDB.Put(types.LicenseKey, s.license)
// 保存 License 到数据库
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
if err != nil {
return fmt.Errorf("保存许可证书失败:%v", err)
return fmt.Errorf("保存 License 到数据库失败: %v", err)
}
return nil
}
@@ -96,6 +110,11 @@ func (s *LicenseService) SyncLicense() {
s.license.IsActive = false
} else {
s.license = license
// 保存 License 到数据库
err = s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).UpdateColumn("value", utils.JsonEncode(s.license)).Error
if err != nil {
logger.Errorf("保存 License 到数据库失败: %v", err)
}
}
urls, err := s.fetchUrlWhiteList()
@@ -109,33 +128,30 @@ func (s *LicenseService) SyncLicense() {
}
func (s *LicenseService) fetchLicense() (*types.License, error) {
//var res struct {
// Code types.BizCode `json:"code"`
// Message string `json:"message"`
// Data License `json:"data"`
//}
//apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/check")
//response, err := req.C().R().
// SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
// SetSuccessResult(&res).Post(apiURL)
//if err != nil {
// return nil, fmt.Errorf("发送激活请求失败: %v", err)
//}
//if response.IsErrorState() {
// return nil, fmt.Errorf("激活失败:%v", response.Status)
//}
//if res.Code != types.Success {
// return nil, fmt.Errorf("激活失败:%v", res.Message)
//}
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data License `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/check")
response, err := req.C().R().
SetBody(map[string]string{"license": s.license.Key, "machine_id": s.machineId}).
SetSuccessResult(&res).Post(apiURL)
if err != nil {
return nil, fmt.Errorf("License 同步失败: %v", err)
}
if response.IsErrorState() {
return nil, fmt.Errorf("License 同步失败:%v", response.Status)
}
if res.Code != types.Success {
return nil, fmt.Errorf("License 同步失败:%v", res.Message)
}
return &types.License{
Key: "abc",
MachineId: "abc",
Configs: types.LicenseConfig{
UserNum: 10000,
DeCopy: false,
},
ExpiredAt: 0,
Key: res.Data.License,
MachineId: res.Data.MachineId,
Configs: res.Data.Configs,
ExpiredAt: res.Data.ExpiredAt,
IsActive: true,
}, nil
}
@@ -146,7 +162,7 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
Message string `json:"message"`
Data []string `json:"data"`
}
apiURL := fmt.Sprintf("%s/%s", s.config.ApiURL, "api/license/urls")
apiURL := fmt.Sprintf("%s/%s", types.GeekAPIURL, "api/license/urls")
response, err := req.C().R().SetSuccessResult(&res).Get(apiURL)
if err != nil {
return nil, fmt.Errorf("发送请求失败: %v", err)
@@ -163,35 +179,46 @@ func (s *LicenseService) fetchUrlWhiteList() ([]string, error) {
// GetLicense 获取许可信息
func (s *LicenseService) GetLicense() *types.License {
if s.license == nil {
var config model.Config
s.db.Model(&model.Config{}).Where("name = ?", types.ConfigKeyLicense).First(&config)
if config.Value != "" {
utils.JsonDecode(config.Value, &s.license)
}
}
return s.license
}
func (s *LicenseService) SetLicense(licenseKey string) {
s.license.Key = licenseKey
}
// IsValidApiURL 判断是否合法的中转 URL
func (s *LicenseService) IsValidApiURL(uri string) error {
// 获得许可授权的直接放行
return nil
//if s.license.IsActive {
// if s.license.MachineId != s.machineId {
// return errors.New("系统使用了盗版的许可证书")
// }
//
// if time.Now().Unix() > s.license.ExpiredAt {
// return errors.New("系统许可证书已经过期")
// }
// return nil
//}
//
//if len(s.urlWhiteList) == 0 {
// urls, err := s.fetchUrlWhiteList()
// if err == nil {
// s.urlWhiteList = urls
// }
//}
//
//for _, v := range s.urlWhiteList {
// if strings.HasPrefix(uri, v) {
// return nil
// }
//}
//return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
if s.license.IsActive {
if s.license.MachineId != s.machineId {
return errors.New("系统使用了盗版的许可证书")
}
if time.Now().Unix() > s.license.ExpiredAt {
return errors.New("系统许可证书已经过期")
}
return nil
}
if len(s.urlWhiteList) == 0 {
urls, err := s.fetchUrlWhiteList()
if err == nil {
s.urlWhiteList = urls
}
}
for _, v := range s.urlWhiteList {
if strings.HasPrefix(uri, v) {
return nil
}
}
return fmt.Errorf("当前 API 地址 %s 不在白名单列表当中。", uri)
}

View File

@@ -1,52 +1,342 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// Copyright 2023 The Geek-AI Authors. All rights reserved.
// Use of this source code is governed by a Apache-2.0 license
// that can be found in the LICENSE file.
// @Author yangjian102621@163.com
// ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"encoding/json"
"fmt"
"geekai/core/types"
"geekai/store"
"geekai/store/model"
"strings"
"github.com/go-redis/redis/v8"
"gorm.io/gorm"
)
const (
// 迁移状态Redis key
MigrationStatusKey = "config_migration:status"
// 迁移完成标志
MigrationCompleted = "completed"
)
// MigrationService 配置迁移服务
type MigrationService struct {
db *gorm.DB
db *gorm.DB
redisClient *redis.Client
appConfig *types.AppConfig
levelDB *store.LevelDB
licenseService *LicenseService
}
func NewMigrationService(db *gorm.DB) *MigrationService {
return &MigrationService{db: db}
func NewMigrationService(db *gorm.DB, redisClient *redis.Client, appConfig *types.AppConfig, levelDB *store.LevelDB, licenseService *LicenseService) *MigrationService {
return &MigrationService{
db: db,
redisClient: redisClient,
appConfig: appConfig,
levelDB: levelDB,
licenseService: licenseService,
}
}
func (s *MigrationService) Migrate() error {
err := s.db.AutoMigrate(
&model.AdminUser{},
&model.ApiKey{},
&model.AppType{},
&model.ChatItem{},
&model.ChatMessage{},
&model.ChatModel{},
&model.ChatRole{},
&model.Config{},
&model.DallJob{},
&model.File{},
&model.Function{},
&model.InviteCode{},
&model.InviteLog{},
&model.Menu{},
&model.MidJourneyJob{},
&model.Order{},
&model.PowerLog{},
&model.Product{},
&model.Redeem{},
&model.SdJob{},
&model.SunoJob{},
&model.User{},
&model.UserLoginLog{},
&model.VideoJob{},
)
return err
func (s *MigrationService) StartMigrate() {
go func() {
s.MigrateConfig(s.appConfig)
s.TableMigration()
s.MigrateLicense()
}()
}
// 迁移 License
func (s *MigrationService) MigrateLicense() {
key := "migrate:license"
if s.redisClient.Get(context.Background(), key).Val() == "1" {
logger.Info("License 已迁移,跳过迁移")
return
}
logger.Info("开始迁移 License...")
var license types.License
err := s.levelDB.Get(types.LicenseKey, &license)
if err != nil {
license = types.License{
Key: "",
MachineId: "",
Configs: types.LicenseConfig{UserNum: 0, DeCopy: false},
ExpiredAt: 0,
IsActive: false,
}
}
logger.Infof("迁移 License: %+v", license)
if err := s.saveConfig(types.ConfigKeyLicense, license); err != nil {
logger.Errorf("迁移 License 失败: %v", err)
return
}
s.licenseService.SetLicense(license.Key)
logger.Info("迁移 License 完成")
s.redisClient.Set(context.Background(), key, "1", 0)
}
// 迁移配置内容
func (s *MigrationService) MigrateConfigContent() error {
// 用户协议
if err := s.saveConfig(types.ConfigKeyPrivacy, map[string]string{
"content": "用户协议内容",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 隐私政策
if err := s.saveConfig(types.ConfigKeyAgreement, map[string]string{
"content": "隐私政策内容",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 思维导图
if err := s.saveConfig(types.ConfigKeyMarkMap, map[string]string{
"content": `# GeekAI 演示站
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
- 基于 Websocket 实现,完美的打字机体验。
- 内置了各种预训练好的角色应用,轻松满足你的各种聊天和应用需求。
- 支持 OPenAIAzure文心一言讯飞星火清华 ChatGLM等多个大语言模型。
- 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件。`,
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 微信登录配置
if err := s.saveConfig(types.ConfigKeyWxLogin, map[string]string{
"api_key": "",
"notify_url": "",
"enabled": "false",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 验证码配置
if err := s.saveConfig(types.ConfigKeyCaptcha, map[string]string{
"api_key": "",
"type": "dot",
"enabled": "false",
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
// 文本审核
if err := s.saveConfig(types.ConfigKeyModeration, map[string]any{
"enable": "false",
"active": "gitee",
"enable_guide": "false",
"guide_prompt": "",
"gitee": map[string]string{
"api_key": "",
"model": "Security-semantic-filtering",
},
"baidu": map[string]string{
"access_key": "",
"secret_key": "",
},
"tencent": map[string]string{
"access_key": "",
"secret_key": "",
},
}); err != nil {
return fmt.Errorf("迁移配置内容失败: %v", err)
}
return nil
}
// 数据表迁移
func (s *MigrationService) TableMigration() {
// 新数据表
s.db.AutoMigrate(&model.Moderation{})
// 订单字段整理
if s.db.Migrator().HasColumn(&model.Order{}, "pay_type") {
s.db.Migrator().RenameColumn(&model.Order{}, "pay_type", "channel")
}
if !s.db.Migrator().HasColumn(&model.Order{}, "checked") {
s.db.Migrator().AddColumn(&model.Order{}, "checked")
}
// 重命名 config 表字段
if s.db.Migrator().HasColumn(&model.Config{}, "config_json") {
s.db.Migrator().RenameColumn(&model.Config{}, "config_json", "value")
}
if s.db.Migrator().HasColumn(&model.Config{}, "marker") {
s.db.Migrator().RenameColumn(&model.Config{}, "marker", "name")
}
if s.db.Migrator().HasIndex(&model.Config{}, "idx_chatgpt_configs_key") {
s.db.Migrator().DropIndex(&model.Config{}, "idx_chatgpt_configs_key")
}
if s.db.Migrator().HasIndex(&model.Config{}, "marker") {
s.db.Migrator().DropIndex(&model.Config{}, "marker")
}
// 手动删除字段
if s.db.Migrator().HasColumn(&model.Order{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.Order{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.ChatItem{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.ChatItem{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.ChatMessage{}, "deleted_at") {
s.db.Migrator().DropColumn(&model.ChatMessage{}, "deleted_at")
}
if s.db.Migrator().HasColumn(&model.User{}, "chat_config") {
s.db.Migrator().DropColumn(&model.User{}, "chat_config")
}
if s.db.Migrator().HasColumn(&model.ChatModel{}, "category") {
s.db.Migrator().DropColumn(&model.ChatModel{}, "category")
}
if s.db.Migrator().HasColumn(&model.ChatModel{}, "description") {
s.db.Migrator().DropColumn(&model.ChatModel{}, "description")
}
if s.db.Migrator().HasColumn(&model.Product{}, "discount") {
s.db.Migrator().DropColumn(&model.Product{}, "discount")
}
if s.db.Migrator().HasColumn(&model.Product{}, "days") {
s.db.Migrator().DropColumn(&model.Product{}, "days")
}
if s.db.Migrator().HasColumn(&model.Product{}, "app_url") {
s.db.Migrator().DropColumn(&model.Product{}, "app_url")
}
if s.db.Migrator().HasColumn(&model.Product{}, "url") {
s.db.Migrator().DropColumn(&model.Product{}, "url")
}
}
// 迁移配置数据
func (s *MigrationService) MigrateConfig(config *types.AppConfig) error {
logger.Info("开始迁移配置到数据库...")
// 迁移支付配置
if err := s.migratePaymentConfig(config); err != nil {
logger.Errorf("迁移支付配置失败: %v", err)
return err
}
// 迁移存储配置
if err := s.migrateStorageConfig(config); err != nil {
logger.Errorf("迁移存储配置失败: %v", err)
return err
}
// 迁移通信配置
if err := s.migrateCommunicationConfig(config); err != nil {
logger.Errorf("迁移通信配置失败: %v", err)
return err
}
// 迁移配置内容
if err := s.MigrateConfigContent(); err != nil {
logger.Errorf("迁移配置内容失败: %v", err)
return err
}
logger.Info("配置迁移完成")
return nil
}
// 迁移支付配置
func (s *MigrationService) migratePaymentConfig(config *types.AppConfig) error {
paymentConfig := types.PaymentConfig{
Alipay: config.AlipayConfig,
Epay: config.GeekPayConfig,
WxPay: config.WechatPayConfig,
}
if err := s.saveConfig(types.ConfigKeyPayment, paymentConfig); err != nil {
return err
}
return nil
}
// 迁移存储配置
func (s *MigrationService) migrateStorageConfig(config *types.AppConfig) error {
ossConfig := types.OSSConfig{
Active: config.OSS.Active,
Local: config.OSS.Local,
Minio: config.OSS.Minio,
QiNiu: config.OSS.QiNiu,
AliYun: config.OSS.AliYun,
}
return s.saveConfig(types.ConfigKeyOss, ossConfig)
}
// 迁移通信配置
func (s *MigrationService) migrateCommunicationConfig(config *types.AppConfig) error {
// SMTP配置
smtpConfig := map[string]any{
"use_tls": config.SmtpConfig.UseTls,
"host": config.SmtpConfig.Host,
"port": config.SmtpConfig.Port,
"app_name": config.SmtpConfig.AppName,
"from": config.SmtpConfig.From,
"password": config.SmtpConfig.Password,
}
if err := s.saveConfig(types.ConfigKeySmtp, smtpConfig); err != nil {
return err
}
// 短信配置
smsConfig := map[string]any{
"active": strings.ToLower(config.SMS.Active),
"aliyun": map[string]any{
"access_key": config.SMS.Ali.AccessKey,
"access_secret": config.SMS.Ali.AccessSecret,
"sign": config.SMS.Ali.Sign,
"code_temp_id": config.SMS.Ali.CodeTempId,
},
"bao": map[string]any{
"username": config.SMS.Bao.Username,
"password": config.SMS.Bao.Password,
"sign": config.SMS.Bao.Sign,
"code_template": config.SMS.Bao.CodeTemplate,
},
}
return s.saveConfig(types.ConfigKeySms, smsConfig)
}
// 保存配置到数据库
func (s *MigrationService) saveConfig(key string, config any) error {
// 检查是否已存在
var existingConfig model.Config
if err := s.db.Where("name", key).First(&existingConfig).Error; err == nil {
// 配置已存在,跳过
logger.Infof("配置 %s 已存在,跳过迁移", key)
return nil
}
// 序列化配置
configJSON, err := json.Marshal(config)
if err != nil {
return err
}
// 保存到数据库
newConfig := model.Config{
Name: key,
Value: string(configJSON),
}
if err := s.db.Create(&newConfig).Error; err != nil {
return err
}
logger.Infof("成功迁移配置 %s", key)
return nil
}

View File

@@ -67,25 +67,6 @@ func (s *Service) Run() {
continue
}
// translate prompt
if utils.HasChinese(task.Prompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.Prompt), task.TranslateModelId)
if err == nil {
task.Prompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// translate negative prompt
if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.TranslatePromptTemplate, task.NegPrompt), task.TranslateModelId)
if err == nil {
task.NegPrompt = content
} else {
logger.Warnf("error with translate prompt: %v", err)
}
}
// use fast mode as default
if task.Mode == "" {
task.Mode = "fast"

View File

@@ -0,0 +1,33 @@
package moderation
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"geekai/core/types"
)
type BaiduAIModeration struct {
config types.ModerationBaiduConfig
}
func NewBaiduAIModeration(sysConfig *types.SystemConfig) *BaiduAIModeration {
return &BaiduAIModeration{
config: sysConfig.Moderation.Baidu,
}
}
func (s *BaiduAIModeration) UpdateConfig(config types.ModerationBaiduConfig) {
s.config = config
}
func (s *BaiduAIModeration) Moderate(text string) (types.ModerationResult, error) {
return types.ModerationResult{}, errors.New("not implemented")
}
var _ Service = (*BaiduAIModeration)(nil)

View File

@@ -0,0 +1,58 @@
package moderation
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"geekai/core/types"
"github.com/imroc/req/v3"
)
type GiteeAIModeration struct {
config types.ModerationGiteeConfig
apiURL string
}
func NewGiteeAIModeration(sysConfig *types.SystemConfig) *GiteeAIModeration {
return &GiteeAIModeration{
config: sysConfig.Moderation.Gitee,
apiURL: "https://ai.gitee.com/v1/moderations",
}
}
func (s *GiteeAIModeration) UpdateConfig(config types.ModerationGiteeConfig) {
s.config = config
}
type GiteeAIModerationResult struct {
ID string `json:"id"`
Model string `json:"model"`
Results []types.ModerationResult `json:"results"`
}
func (s *GiteeAIModeration) Moderate(text string) (types.ModerationResult, error) {
body := map[string]any{
"input": text,
"model": s.config.Model,
}
var res GiteeAIModerationResult
r, err := req.C().R().SetHeader("Authorization", "Bearer "+s.config.ApiKey).SetBody(body).SetSuccessResult(&res).Post(s.apiURL)
if err != nil {
return types.ModerationResult{}, err
}
if r.IsErrorState() {
return types.ModerationResult{}, errors.New(r.String())
}
return res.Results[0], nil
}
var _ Service = (*GiteeAIModeration)(nil)

View File

@@ -0,0 +1,58 @@
package moderation
import (
"geekai/core/types"
logger2 "geekai/logger"
)
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
var logger = logger2.GetLogger()
type Service interface {
Moderate(text string) (types.ModerationResult, error)
}
type ServiceManager struct {
gitee *GiteeAIModeration
baidu *BaiduAIModeration
tencent *TencentAIModeration
active string
}
func NewServiceManager(gitee *GiteeAIModeration, baidu *BaiduAIModeration, tencent *TencentAIModeration) *ServiceManager {
return &ServiceManager{
gitee: gitee,
baidu: baidu,
tencent: tencent,
}
}
func (s *ServiceManager) GetService() Service {
switch s.active {
case types.ModerationBaidu:
return s.baidu
case types.ModerationTencent:
return s.tencent
default:
return s.gitee
}
}
func (s *ServiceManager) UpdateConfig(config types.ModerationConfig) {
switch config.Active {
case types.ModerationGitee:
s.gitee.UpdateConfig(config.Gitee)
case types.ModerationBaidu:
s.baidu.UpdateConfig(config.Baidu)
case types.ModerationTencent:
s.tencent.UpdateConfig(config.Tencent)
}
s.active = config.Active
}

View File

@@ -0,0 +1,33 @@
package moderation
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"errors"
"geekai/core/types"
)
type TencentAIModeration struct {
config types.ModerationTencentConfig
}
func NewTencentAIModeration(sysConfig *types.SystemConfig) *TencentAIModeration {
return &TencentAIModeration{
config: sysConfig.Moderation.Tencent,
}
}
func (s *TencentAIModeration) UpdateConfig(config types.ModerationTencentConfig) {
s.config = config
}
func (s *TencentAIModeration) Moderate(text string) (types.ModerationResult, error) {
return types.ModerationResult{}, errors.New("not implemented")
}
var _ Service = (*TencentAIModeration)(nil)

View File

@@ -23,35 +23,35 @@ import (
)
type AliYunOss struct {
config *types.AliYunOssConfig
config types.AliYunOssConfig
bucket *oss.Bucket
proxyURL string
}
func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
config := &appConfig.OSS.AliYun
// 创建 OSS 客户端
func NewAliYunOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*AliYunOss, error) {
s := &AliYunOss{
proxyURL: appConfig.ProxyURL,
}
err := s.UpdateConfig(sysConfig.OSS.AliYun)
if err != nil {
logger.Warnf("阿里云OSS初始化失败: %v", err)
}
return s, nil
}
func (s *AliYunOss) UpdateConfig(config types.AliYunOssConfig) error {
client, err := oss.New(config.Endpoint, config.AccessKey, config.AccessSecret)
if err != nil {
return nil, err
return err
}
// 获取存储空间
bucket, err := client.Bucket(config.Bucket)
if err != nil {
return nil, err
return err
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return &AliYunOss{
config: config,
bucket: bucket,
proxyURL: appConfig.ProxyURL,
}, nil
s.bucket = bucket
s.config = config
return nil
}
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
@@ -68,7 +68,7 @@ func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
defer src.Close()
fileExt := filepath.Ext(file.Filename)
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
// 上传文件
err = s.bucket.PutObject(objectKey, src)
if err != nil {
@@ -102,7 +102,7 @@ func (s AliYunOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
if ext == "" {
ext = filepath.Ext(parse.Path)
}
objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(fileData))
if err != nil {
@@ -116,7 +116,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
// 上传文件字节数据
err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
if err != nil {
@@ -128,8 +128,7 @@ func (s AliYunOss) PutBase64(base64Img string) (string, error) {
func (s AliYunOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
objectKey = filepath.Base(fileURL)
} else {
objectKey = fileURL
}

View File

@@ -21,17 +21,21 @@ import (
)
type LocalStorage struct {
config *types.LocalStorageConfig
config types.LocalStorageConfig
proxyURL string
}
func NewLocalStorage(config *types.AppConfig) LocalStorage {
return LocalStorage{
config: &config.OSS.Local,
proxyURL: config.ProxyURL,
func NewLocalStorage(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *LocalStorage {
return &LocalStorage{
config: sysConfig.OSS.Local,
proxyURL: appConfig.ProxyURL,
}
}
func (s *LocalStorage) UpdateConfig(config types.LocalStorageConfig) {
s.config = config
}
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
file, err := ctx.FormFile(name)
if err != nil {

View File

@@ -24,24 +24,32 @@ import (
)
type MiniOss struct {
config *types.MiniOssConfig
config types.MiniOssConfig
client *minio.Client
proxyURL string
}
func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
config := &appConfig.OSS.Minio
func NewMiniOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) (*MiniOss, error) {
s := &MiniOss{proxyURL: appConfig.ProxyURL}
err := s.UpdateConfig(sysConfig.OSS.Minio)
if err != nil {
logger.Warnf("MinioOSS初始化失败: %v", err)
}
return s, nil
}
func (s *MiniOss) UpdateConfig(config types.MiniOssConfig) error {
minioClient, err := minio.New(config.Endpoint, &minio.Options{
Creds: credentials.NewStaticV4(config.AccessKey, config.AccessSecret, ""),
Secure: config.UseSSL,
})
if err != nil {
return MiniOss{}, err
return err
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
s.config = config
s.client = minioClient
return nil
}
func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
@@ -62,7 +70,7 @@ func (s MiniOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string,
if ext == "" {
ext = filepath.Ext(parse.Path)
}
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
@@ -89,7 +97,7 @@ func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
defer fileReader.Close()
fileExt := filepath.Ext(file.Filename)
filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
ContentType: file.Header.Get("Body-Type"),
})
@@ -111,7 +119,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
info, err := s.client.PutObject(
context.Background(),
s.config.Bucket,
@@ -128,8 +136,7 @@ func (s MiniOss) PutBase64(base64Img string) (string, error) {
func (s MiniOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
objectKey = filepath.Base(fileURL)
} else {
objectKey = fileURL
}

View File

@@ -24,18 +24,24 @@ import (
"github.com/qiniu/go-sdk/v7/storage"
)
type QinNiuOss struct {
config *types.QiNiuOssConfig
type QiNiuOss struct {
config types.QiNiuOssConfig
mac *qbox.Mac
putPolicy storage.PutPolicy
uploader *storage.FormUploader
manager *storage.BucketManager
bucket *storage.BucketManager
proxyURL string
}
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
config := &appConfig.OSS.QiNiu
// build storage uploader
func NewQiNiuOss(sysConfig *types.SystemConfig, appConfig *types.AppConfig) *QiNiuOss {
s := &QiNiuOss{
proxyURL: appConfig.ProxyURL,
}
s.UpdateConfig(sysConfig.OSS.QiNiu)
return s
}
func (s *QiNiuOss) UpdateConfig(config types.QiNiuOssConfig) {
zone, ok := storage.GetRegionByID(storage.RegionID(config.Zone))
if !ok {
zone = storage.ZoneHuanan
@@ -47,20 +53,13 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
putPolicy := storage.PutPolicy{
Scope: config.Bucket,
}
if config.SubDir == "" {
config.SubDir = "gpt"
}
return QinNiuOss{
config: config,
mac: mac,
putPolicy: putPolicy,
uploader: formUploader,
manager: storage.NewBucketManager(mac, &storeConfig),
proxyURL: appConfig.ProxyURL,
}
s.config = config
s.mac = mac
s.putPolicy = putPolicy
s.uploader = formUploader
s.bucket = storage.NewBucketManager(mac, &storeConfig)
}
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
func (s QiNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
// 解析表单
file, err := ctx.FormFile(name)
if err != nil {
@@ -74,7 +73,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
defer src.Close()
fileExt := filepath.Ext(file.Filename)
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
// 上传文件
ret := storage.PutRet{}
extra := storage.PutExtra{}
@@ -93,7 +92,7 @@ func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
}
func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
func (s QiNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string, error) {
var fileData []byte
var err error
if useProxy {
@@ -111,7 +110,7 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
if ext == "" {
ext = filepath.Ext(parse.Path)
}
key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), ext)
key := fmt.Sprintf("%d%s", time.Now().UnixMicro(), ext)
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
@@ -122,12 +121,12 @@ func (s QinNiuOss) PutUrlFile(fileURL string, ext string, useProxy bool) (string
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
func (s QiNiuOss) PutBase64(base64Img string) (string, error) {
imageData, err := base64.StdEncoding.DecodeString(base64Img)
if err != nil {
return "", fmt.Errorf("error decoding base64:%v", err)
}
objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
objectKey := fmt.Sprintf("%d.png", time.Now().UnixMicro())
ret := storage.PutRet{}
extra := storage.PutExtra{}
// 上传文件字节数据
@@ -138,16 +137,15 @@ func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
}
func (s QinNiuOss) Delete(fileURL string) error {
func (s QiNiuOss) Delete(fileURL string) error {
var objectKey string
if strings.HasPrefix(fileURL, "http") {
filename := filepath.Base(fileURL)
objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
objectKey = filepath.Base(fileURL)
} else {
objectKey = fileURL
}
return s.manager.Delete(s.config.Bucket, objectKey)
return s.bucket.Delete(s.config.Bucket, objectKey)
}
var _ Uploader = QinNiuOss{}
var _ Uploader = QiNiuOss{}

View File

@@ -9,10 +9,10 @@ package oss
import "github.com/gin-gonic/gin"
const Local = "LOCAL"
const Minio = "MINIO"
const QiNiu = "QINIU"
const AliYun = "ALIYUN"
const Local = "local"
const Minio = "minio"
const QiNiu = "qiniu"
const AliYun = "aliyun"
type File struct {
Name string `json:"name"`

View File

@@ -9,45 +9,58 @@ package oss
import (
"geekai/core/types"
"strings"
logger2 "geekai/logger"
)
var logger = logger2.GetLogger()
type UploaderManager struct {
handler Uploader
local *LocalStorage
aliyun *AliYunOss
mini *MiniOss
qiniu *QiNiuOss
active string
}
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
active := Local
if config.OSS.Active != "" {
active = strings.ToUpper(config.OSS.Active)
}
var handler Uploader
switch active {
case Local:
handler = NewLocalStorage(config)
break
case Minio:
client, err := NewMiniOss(config)
if err != nil {
return nil, err
}
handler = client
break
case QiNiu:
handler = NewQiNiuOss(config)
break
case AliYun:
client, err := NewAliYunOss(config)
if err != nil {
return nil, err
}
handler = client
break
func NewUploaderManager(sysConfig *types.SystemConfig, local *LocalStorage, aliyun *AliYunOss, mini *MiniOss, qiniu *QiNiuOss) (*UploaderManager, error) {
if sysConfig.OSS.Active == "" {
sysConfig.OSS.Active = Local
}
return &UploaderManager{handler: handler}, nil
return &UploaderManager{
active: sysConfig.OSS.Active,
local: local,
aliyun: aliyun,
mini: mini,
qiniu: qiniu,
}, nil
}
func (m *UploaderManager) GetUploadHandler() Uploader {
return m.handler
switch m.active {
case Local:
return m.local
case AliYun:
return m.aliyun
case Minio:
return m.mini
case QiNiu:
return m.qiniu
}
return m.local
}
func (m *UploaderManager) UpdateConfig(config types.OSSConfig) {
switch config.Active {
case Local:
m.local.UpdateConfig(config.Local)
case AliYun:
m.aliyun.UpdateConfig(config.AliYun)
case Minio:
m.mini.UpdateConfig(config.Minio)
case QiNiu:
m.qiniu.UpdateConfig(config.QiNiu)
}
m.active = config.Active
}

View File

@@ -12,129 +12,98 @@ import (
"fmt"
"geekai/core/types"
logger2 "geekai/logger"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/alipay"
"net/http"
"os"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/alipay"
)
type AlipayService struct {
config *types.AlipayConfig
client *alipay.Client
config *types.AlipayConfig
}
var logger = logger2.GetLogger()
func NewAlipayService(appConfig *types.AppConfig) (*AlipayService, error) {
config := appConfig.AlipayConfig
func NewAlipayService(sysConfig *types.SystemConfig) (*AlipayService, error) {
config := sysConfig.Payment.Alipay
if !config.Enabled {
logger.Info("Disabled Alipay service")
return nil, nil
logger.Debug("Disabled Alipay service")
}
priKey, err := readKey(config.PrivateKey)
service := &AlipayService{config: &config}
if config.Enabled {
err := service.UpdateConfig(&config)
if err != nil {
logger.Errorf("支付宝服务初始化失败: %v", err)
}
}
return service, nil
}
func (s *AlipayService) UpdateConfig(config *types.AlipayConfig) error {
client, err := alipay.NewClient(config.AppId, config.PrivateKey, !config.SandBox)
if err != nil {
return nil, fmt.Errorf("error with read App Private key: %v", err)
return fmt.Errorf("error with initialize alipay service: %v", err)
}
client, err := alipay.NewClient(config.AppId, priKey, !config.SandBox)
if err != nil {
return nil, fmt.Errorf("error with initialize alipay service: %v", err)
s.client = client
s.config = config
if os.Getenv("GEEKAI_DEBUG") == "true" {
logger.Info("Alipay Debug mode is enabled")
client.DebugSwitch = gopay.DebugOn
}
//client.DebugSwitch = gopay.DebugOn // 开启调试模式
client.SetLocation(alipay.LocationShanghai). // 设置时区,不设置或出错均为默认服务器时间
SetCharset(alipay.UTF8). // 设置字符编码,不设置默认 utf-8
SetSignType(alipay.RSA2) // 设置签名类型,不设置默认 RSA2
if err = client.SetCertSnByPath(config.PublicKey, config.RootCert, config.AlipayPublicKey); err != nil {
return nil, fmt.Errorf("error with load payment public key: %v", err)
}
return &AlipayService{config: &config, client: client}, nil
return nil
}
type AlipayParams struct {
OutTradeNo string `json:"out_trade_no"`
Subject string `json:"subject"`
TotalFee string `json:"total_fee"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *AlipayService) PayMobile(params AlipayParams) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("quit_url", params.ReturnURL)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "QUICK_WAP_WAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradeWapPay(context.Background(), bm)
}
func (s *AlipayService) PayPC(params AlipayParams) (string, error) {
func (s *AlipayService) Pay(params PayRequest) (string, error) {
bm := make(gopay.BodyMap)
bm.Set("subject", params.Subject)
bm.Set("out_trade_no", params.OutTradeNo)
bm.Set("total_amount", params.TotalFee)
bm.Set("product_code", "FAST_INSTANT_TRADE_PAY")
return s.client.SetNotifyUrl(params.NotifyURL).SetReturnUrl(params.ReturnURL).TradePagePay(context.Background(), bm)
return s.client.TradeWapPay(context.Background(), bm)
}
func (s *AlipayService) Query(outTradeNo string) (OrderInfo, error) {
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with trade query: %v", err)
}
switch rsp.Response.TradeStatus {
case "TRADE_SUCCESS":
logger.Debugf("支付宝查询订单成功:%+v", rsp.Response)
return OrderInfo{
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Status: Success,
PayTime: rsp.Response.SendPayDate,
}, nil
case "TRADE_CLOSED":
return OrderInfo{Status: Closed}, nil
default:
return OrderInfo{}, fmt.Errorf("error with trade query: %v", rsp.Response.TradeStatus)
}
}
// TradeVerify 交易验证
func (s *AlipayService) TradeVerify(request *http.Request) NotifyVo {
func (s *AlipayService) TradeVerify(request *http.Request) (OrderInfo, error) {
notifyReq, err := alipay.ParseNotifyToBodyMap(request) // c.Request 是 gin 框架的写法
if err != nil {
return NotifyVo{
Status: Failure,
Message: "error with parse notify request: " + err.Error(),
}
return OrderInfo{}, fmt.Errorf("error with parse notify request: %v", err)
}
_, err = alipay.VerifySignWithCert(s.config.AlipayPublicKey, notifyReq)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "error with verify sign: " + err.Error(),
}
return OrderInfo{}, fmt.Errorf("error with verify sign: %v", err)
}
return s.TradeQuery(request.Form.Get("out_trade_no"))
return s.Query(request.Form.Get("out_trade_no"))
}
func (s *AlipayService) TradeQuery(outTradeNo string) NotifyVo {
bm := make(gopay.BodyMap)
bm.Set("out_trade_no", outTradeNo)
//查询订单
rsp, err := s.client.TradeQuery(context.Background(), bm)
if err != nil {
return NotifyVo{
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo + err.Error(),
}
}
if rsp.Response.TradeStatus == "TRADE_SUCCESS" {
return NotifyVo{
Status: Success,
OutTradeNo: rsp.Response.OutTradeNo,
TradeId: rsp.Response.TradeNo,
Amount: rsp.Response.TotalAmount,
Subject: rsp.Response.Subject,
Message: "OK",
}
} else {
return NotifyVo{
Status: Failure,
Message: "异步查询验证订单信息发生错误" + outTradeNo,
}
}
}
func readKey(filename string) (string, error) {
data, err := os.ReadFile(filename)
if err != nil {
return "", err
}
return string(data), nil
}
var _ PayService = (*AlipayService)(nil)

View File

@@ -22,41 +22,30 @@ import (
"time"
)
// GeekPayService Geek 支付服务
type GeekPayService struct {
config *types.GeekPayConfig
// EPayService 支付服务
type EPayService struct {
config *types.EpayConfig
}
func NewJPayService(appConfig *types.AppConfig) *GeekPayService {
return &GeekPayService{
config: &appConfig.GeekPayConfig,
func NewEPayService(sysConfig *types.SystemConfig) *EPayService {
return &EPayService{
config: &sysConfig.Payment.Epay,
}
}
type GeekPayParams struct {
Method string `json:"method"` // 接口类型
Device string `json:"device"` // 设备类型
Type string `json:"type"` // 支付方式
OutTradeNo string `json:"out_trade_no"` // 商户订单号
Name string `json:"name"` // 商品名称
Money string `json:"money"` // 商品金额
ClientIP string `json:"clientip"` //用户IP地址
SubOpenId string `json:"sub_openid"` // 微信用户 openid仅小程序支付需要
SubAppId string `json:"sub_appid"` // 小程序 AppId仅小程序支付需要
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
func (s *EPayService) UpdateConfig(config *types.EpayConfig) {
s.config = config
}
// Pay 支付订单
func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
func (s *EPayService) Pay(params PayRequest) (string, error) {
p := map[string]string{
"pid": s.config.AppId,
//"method": params.Method,
"pid": s.config.AppId,
"device": params.Device,
"type": params.Type,
"type": params.PayWay,
"out_trade_no": params.OutTradeNo,
"name": params.Name,
"money": params.Money,
"name": params.Subject,
"money": params.TotalFee,
"clientip": params.ClientIP,
"notify_url": params.NotifyURL,
"return_url": params.ReturnURL,
@@ -64,10 +53,21 @@ func (s *GeekPayService) Pay(params GeekPayParams) (*GeekPayResp, error) {
}
p["sign"] = s.Sign(p)
p["sign_type"] = "MD5"
return s.sendRequest(s.config.ApiURL, p)
resp, err := s.sendRequest(s.config.ApiURL, p)
if err != nil {
return "", err
}
if resp.Code != 1 {
return "", errors.New(resp.Msg)
}
if resp.PayURL != "" {
return resp.PayURL, nil
} else {
return resp.QrCode, nil
}
}
func (s *GeekPayService) Sign(params map[string]string) string {
func (s *EPayService) Sign(params map[string]string) string {
// 按字母顺序排序参数
var keys []string
for k := range params {
@@ -100,7 +100,7 @@ type GeekPayResp struct {
UrlScheme string `json:"urlscheme"` // 小程序跳转支付链接
}
func (s *GeekPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
func (s *EPayService) sendRequest(endpoint string, params map[string]string) (*GeekPayResp, error) {
form := url.Values{}
for k, v := range params {
form.Add(k, v)
@@ -137,3 +137,61 @@ func (s *GeekPayService) sendRequest(endpoint string, params map[string]string)
}
return &r, nil
}
func (s *EPayService) Query(outTradeNo string) (OrderInfo, error) {
params := url.Values{}
params.Set("act", "order")
params.Set("pid", s.config.AppId)
params.Set("key", s.config.PrivateKey)
params.Set("out_trade_no", outTradeNo)
apiURL := fmt.Sprintf("%s/api.php?%s", s.config.ApiURL, params.Encode())
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
client := &http.Client{Transport: tr}
resp, err := client.Get(apiURL)
if err != nil {
return OrderInfo{}, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return OrderInfo{}, err
}
logger.Debugf(string(body))
var result struct {
Code int `json:"code"`
Msg string `json:"msg"`
Status string `json:"status"`
Name string `json:"name"`
Money string `json:"money"`
EndTime string `json:"endtime"`
TradeNo string `json:"trade_no"`
}
if err := json.Unmarshal(body, &result); err != nil {
return OrderInfo{}, errors.New("订单查询响应解析失败")
}
if result.Code != 1 {
return OrderInfo{}, errors.New(result.Msg)
}
logger.Debugf("订单信息:%+v", result)
orderInfo := OrderInfo{
OutTradeNo: outTradeNo,
TradeId: result.TradeNo,
Amount: result.Money,
PayTime: result.EndTime,
}
if result.Status == "1" {
orderInfo.Status = Success
} else {
orderInfo.Status = Failure
}
return orderInfo, nil
}
var _ PayService = (*EPayService)(nil)

View File

@@ -1,171 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"crypto/md5"
"encoding/hex"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"time"
)
type HuPiPayService struct {
appId string
appSecret string
apiURL string
}
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
return &HuPiPayService{
appId: config.HuPiPayConfig.AppId,
appSecret: config.HuPiPayConfig.AppSecret,
apiURL: config.HuPiPayConfig.ApiURL,
}
}
type HuPiPayParams struct {
AppId string `json:"appid"`
Version string `json:"version"`
TradeOrderId string `json:"trade_order_id"`
TotalFee string `json:"total_fee"`
Title string `json:"title"`
NotifyURL string `json:"notify_url"`
ReturnURL string `json:"return_url"`
WapName string `json:"wap_name"`
CallbackURL string `json:"callback_url"`
Time string `json:"time"`
NonceStr string `json:"nonce_str"`
Type string `json:"type"`
WapUrl string `json:"wap_url"`
}
type HuPiPayResp struct {
Openid interface{} `json:"openid"`
UrlQrcode string `json:"url_qrcode"`
URL string `json:"url"`
ErrCode int `json:"errcode"`
ErrMsg string `json:"errmsg,omitempty"`
}
// Pay 执行支付请求操作
func (s *HuPiPayService) Pay(params HuPiPayParams) (HuPiPayResp, error) {
data := url.Values{}
simple := strconv.FormatInt(time.Now().Unix(), 10)
params.AppId = s.appId
params.Time = simple
params.NonceStr = simple
encode := utils.JsonEncode(params)
m := make(map[string]string)
_ = utils.JsonDecode(encode, &m)
for k, v := range m {
data.Add(k, fmt.Sprintf("%v", v))
}
// 生成签名
data.Add("hash", s.Sign(data))
// 发送支付请求
apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with requst api: %v", err)
}
defer resp.Body.Close()
all, err := io.ReadAll(resp.Body)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with reading response: %v", err)
}
var res HuPiPayResp
err = utils.JsonDecode(string(all), &res)
if err != nil {
return HuPiPayResp{}, fmt.Errorf("error with decode payment result: %v", err)
}
if res.ErrCode != 0 {
return HuPiPayResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
}
return res, nil
}
// Sign 签名方法
func (s *HuPiPayService) Sign(params url.Values) string {
params.Del(`Sign`)
var keys = make([]string, 0, 0)
for key := range params {
if params.Get(key) != `` {
keys = append(keys, key)
}
}
sort.Strings(keys)
var pList = make([]string, 0, 0)
for _, key := range keys {
var value = strings.TrimSpace(params.Get(key))
if len(value) > 0 {
pList = append(pList, key+"="+value)
}
}
var src = strings.Join(pList, "&")
src += s.appSecret
md5bs := md5.Sum([]byte(src))
return hex.EncodeToString(md5bs[:])
}
// Check 校验订单状态
func (s *HuPiPayService) Check(outTradeNo string) error {
data := url.Values{}
data.Add("appid", s.appId)
data.Add("out_trade_order", outTradeNo)
stamp := strconv.FormatInt(time.Now().Unix(), 10)
data.Add("time", stamp)
data.Add("nonce_str", stamp)
data.Add("hash", s.Sign(data))
apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
resp, err := http.PostForm(apiURL, data)
if err != nil {
return fmt.Errorf("error with http reqeust: %v", err)
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("error with reading response: %v", err)
}
var r struct {
ErrCode int `json:"errcode"`
Data struct {
Status string `json:"status"`
OpenOrderId string `json:"open_order_id"`
} `json:"data,omitempty"`
ErrMsg string `json:"errmsg"`
Hash string `json:"hash"`
}
err = utils.JsonDecode(string(body), &r)
if err != nil {
return fmt.Errorf("error with decode response: %v", err)
}
if r.ErrCode == 0 && r.Data.Status == "OD" {
return nil
} else {
logger.Debugf("%+v", r)
return errors.New("order not paid" + r.ErrMsg)
}
}

View File

@@ -0,0 +1,54 @@
package payment
// 支付渠道定义
const PayChannelAL = "alipay" // 支付宝
const PayChannelWX = "wxpay" // 微信支付
const PayChannelEpay = "epay" // 易支付
// 支付方式
const PayWayAL = "alipay"
const PayWayWX = "wxpay"
const (
Success = 0
Failure = 1
Closed = 2
)
type PayRequest struct {
OutTradeNo string // 商户订单号
Subject string // 商品名称
TotalFee string // 商品金额
ReturnURL string // 回调地址
NotifyURL string // 回调地址
// 易支付专有参数
Method string // 接口类型
Device string // 设备类型
PayWay string // 支付方式
ClientIP string //用户IP地址
OpenID string // 用户openid
}
type OrderInfo struct {
Mchid string // 商户号
OutTradeNo string // 商户订单号
TradeId string // 交易号
Amount string // 金额
Status int // 状态 0: 未支付 1: 已支付 2: 已关闭
PayTime string // 完成支付时间
}
func (o OrderInfo) Closed() bool {
return o.Status == Closed
}
func (o OrderInfo) Success() bool {
return o.Status == Success
}
type PayService interface {
Pay(params PayRequest) (string, error) // 生成支付链接
Query(outTradeNo string) (OrderInfo, error) // 查询订单
}

View File

@@ -1,19 +0,0 @@
package payment
type NotifyVo struct {
Status int
OutTradeNo string // 商户订单号
TradeId string // 交易ID
Amount string // 交易金额
Message string
Subject string
}
func (v NotifyVo) Success() bool {
return v.Status == Success
}
const (
Success = 0
Failure = 1
)

View File

@@ -1,144 +0,0 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
"net/http"
"time"
)
type WechatPayService struct {
config *types.WechatPayConfig
client *wechat.ClientV3
}
func NewWechatService(appConfig *types.AppConfig) (*WechatPayService, error) {
config := appConfig.WechatPayConfig
if !config.Enabled {
logger.Info("Disabled WechatPay service")
return nil, nil
}
priKey, err := readKey(config.PrivateKey)
if err != nil {
return nil, fmt.Errorf("error with read App Private key: %v", err)
}
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, priKey)
if err != nil {
return nil, fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return nil, fmt.Errorf("error with autoVerifySign: %v", err)
}
//client.DebugSwitch = gopay.DebugOn
return &WechatPayService{config: &config, client: client}, nil
}
type WechatPayParams struct {
OutTradeNo string `json:"out_trade_no"`
TotalFee int `json:"total_fee"`
Subject string `json:"subject"`
ClientIP string `json:"client_ip"`
ReturnURL string `json:"return_url"`
NotifyURL string `json:"notify_url"`
}
func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
})
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", params.TotalFee).
Set("currency", "CNY")
}).
SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP).
SetBodyMap("h5_info", func(bm gopay.BodyMap) {
bm.Set("type", "Wap")
})
})
wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.H5Url, nil
}
type NotifyResponse struct {
Code string `json:"code"`
Message string `xml:"message"`
}
// TradeVerify 交易验证
func (s *WechatPayService) TradeVerify(request *http.Request) NotifyVo {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return NotifyVo{Status: 1, Message: fmt.Sprintf("error with client v3 parse notify: %v", err)}
}
// TODO: 这里验签程序有 Bug一直报错crypto/rsa: verification error先暂时取消验签
//err = notifyReq.VerifySignByPK(s.client.WxPublicKey())
//if err != nil {
// return fmt.Errorf("error with client v3 verify sign: %v", err)
//}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return NotifyVo{Status: Failure, Message: fmt.Sprintf("error with client v3 decrypt: %v", err)}
}
return NotifyVo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
}
}

View File

@@ -0,0 +1,217 @@
package payment
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"fmt"
"geekai/core/types"
"geekai/utils"
"net/http"
"os"
"time"
"github.com/go-pay/gopay"
"github.com/go-pay/gopay/wechat/v3"
)
type WxPayService struct {
config *types.WxPayConfig
client *wechat.ClientV3
}
func NewWxpayService(sysConfig *types.SystemConfig) (*WxPayService, error) {
config := sysConfig.Payment.WxPay
if !config.Enabled {
logger.Debug("Disabled WechatPay service")
}
service := &WxPayService{config: &config}
if config.Enabled {
err := service.UpdateConfig(&config)
if err != nil {
logger.Errorf("微信支付服务初始化失败: %v", err)
}
}
return service, nil
}
func (s *WxPayService) UpdateConfig(config *types.WxPayConfig) error {
client, err := wechat.NewClientV3(config.MchId, config.SerialNo, config.ApiV3Key, config.PrivateKey)
if err != nil {
return fmt.Errorf("error with initialize WechatPay service: %v", err)
}
err = client.AutoVerifySign()
if err != nil {
return fmt.Errorf("error with autoVerifySign: %v", err)
}
s.client = client
if os.Getenv("GEEKAI_DEBUG") == "true" {
logger.Info("WechatPay Debug mode is enabled")
client.DebugSwitch = gopay.DebugOn
}
s.config = config
return nil
}
func (s *WxPayService) Pay(params PayRequest) (string, error) {
expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// 初始化 BodyMap
bm := make(gopay.BodyMap)
bm.Set("appid", s.config.AppId).
Set("mchid", s.config.MchId).
Set("description", params.Subject).
Set("out_trade_no", params.OutTradeNo).
Set("time_expire", expire).
Set("notify_url", params.NotifyURL).
SetBodyMap("amount", func(bm gopay.BodyMap) {
bm.Set("total", utils.IntValue(params.TotalFee, 0)).
Set("currency", "CNY")
})
logger.Debugf("wxpay params: %+v", bm)
if params.Device == "mobile" {
bm.SetBodyMap("scene_info", func(bm gopay.BodyMap) {
bm.Set("payer_client_ip", params.ClientIP)
}).SetBodyMap("payer", func(bm gopay.BodyMap) {
bm.Set("openid", params.OpenID)
})
wxRsp, err := s.client.V3TransactionJsapi(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Jsapi: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.PrepayId, nil
} else if params.Device == "pc" {
wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
if err != nil {
return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
}
if wxRsp.Code != wechat.Success {
return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
}
return wxRsp.Response.CodeUrl, nil
}
return "", nil
}
func (s *WxPayService) Query(outTradeNo string) (OrderInfo, error) {
wxRsp, err := s.client.V3TransactionQueryOrder(context.Background(), wechat.OutTradeNo, outTradeNo)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 transaction query: %v", err)
}
if wxRsp.Code != wechat.Success {
return OrderInfo{}, fmt.Errorf("error status with querying order: %v", wxRsp.Error)
}
if wxRsp.Response.TradeState == "CLOSED" {
return OrderInfo{Status: Closed}, nil
}
orderInfo := OrderInfo{
OutTradeNo: wxRsp.Response.OutTradeNo,
TradeId: wxRsp.Response.TransactionId,
Amount: fmt.Sprintf("%d", wxRsp.Response.Amount.Total/100),
PayTime: wxRsp.Response.SuccessTime,
}
if wxRsp.Response.TradeState == "SUCCESS" {
orderInfo.Status = Success
} else {
orderInfo.Status = Failure
}
return orderInfo, nil
}
// TradeVerify 交易验证
func (s *WxPayService) TradeVerify(request *http.Request) (OrderInfo, error) {
notifyReq, err := wechat.V3ParseNotify(request)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 parse notify: %v", err)
}
// 解密支付密文,验证订单信息
result, err := notifyReq.DecryptPayCipherText(s.config.ApiV3Key)
if err != nil {
return OrderInfo{}, fmt.Errorf("error with client v3 decrypt: %v", err)
}
return OrderInfo{
Status: Success,
OutTradeNo: result.OutTradeNo,
TradeId: result.TransactionId,
Amount: fmt.Sprintf("%.2f", float64(result.Amount.Total)/100),
PayTime: result.SuccessTime,
}, nil
}
// func (s *WechatPayService) PayUrlNative(params WechatPayParams) (string, error) {
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// // 初始化 BodyMap
// bm := make(gopay.BodyMap)
// bm.Set("appid", s.config.AppId).
// Set("mchid", s.config.MchId).
// Set("description", params.Subject).
// Set("out_trade_no", params.OutTradeNo).
// Set("time_expire", expire).
// Set("notify_url", params.NotifyURL).
// SetBodyMap("amount", func(bm gopay.BodyMap) {
// bm.Set("total", params.TotalFee).
// Set("currency", "CNY")
// })
// wxRsp, err := s.client.V3TransactionNative(context.Background(), bm)
// if err != nil {
// return "", fmt.Errorf("error with client v3 transaction Native: %v", err)
// }
// if wxRsp.Code != wechat.Success {
// return "", fmt.Errorf("error status with generating pay url: %v", wxRsp.Error)
// }
// return wxRsp.Response.CodeUrl, nil
// }
// func (s *WechatPayService) PayUrlH5(params WechatPayParams) (string, error) {
// expire := time.Now().Add(10 * time.Minute).Format(time.RFC3339)
// // 初始化 BodyMap
// bm := make(gopay.BodyMap)
// bm.Set("appid", s.config.AppId).
// Set("mchid", s.config.MchId).
// Set("description", params.Subject).
// Set("out_trade_no", params.OutTradeNo).
// Set("time_expire", expire).
// Set("notify_url", params.NotifyURL).
// SetBodyMap("amount", func(bm gopay.BodyMap) {
// bm.Set("total", params.TotalFee).
// Set("currency", "CNY")
// }).
// SetBodyMap("scene_info", func(bm gopay.BodyMap) {
// bm.Set("payer_client_ip", params.ClientIP).
// SetBodyMap("h5_info", func(bm gopay.BodyMap) {
// bm.Set("type", "Wap")
// })
// })
// wxRsp, err := s.client.V3TransactionH5(context.Background(), bm)
// if err != nil {
// return "", fmt.Errorf("error with client v3 transaction H5: %v", err)
// }
// if wxRsp.Code != wechat.Success {
// return "", fmt.Errorf("error with generating pay url: %v", wxRsp.Error)
// }
// return wxRsp.Response.H5Url, nil
// }
// type NotifyResponse struct {
// Code string `json:"code"`
// Message string `xml:"message"`
// }
var _ PayService = (*WxPayService)(nil)

View File

@@ -10,36 +10,57 @@ package sms
import (
"fmt"
"geekai/core/types"
"github.com/aliyun/alibaba-cloud-sdk-go/services/dysmsapi"
)
type AliYunSmsService struct {
config *types.SmsConfigAli
config types.SmsConfigAli
client *dysmsapi.Client
domain string
zoneId string
}
func NewAliYunSmsService(appConfig *types.AppConfig) (*AliYunSmsService, error) {
config := &appConfig.SMS.Ali
// 创建阿里云短信客户端
func NewAliYunSmsService(sysConfig *types.SystemConfig) (*AliYunSmsService, error) {
config := sysConfig.SMS.Ali
domain := "dysmsapi.aliyuncs.com"
zoneId := "cn-hangzhou"
s := AliYunSmsService{
config: config,
domain: domain,
zoneId: zoneId,
}
if sysConfig.SMS.Active == Ali {
err := s.UpdateConfig(config)
if err != nil {
logger.Errorf("阿里云短信初始化失败: %v", err)
}
}
return &s, nil
}
func (s *AliYunSmsService) UpdateConfig(config types.SmsConfigAli) error {
client, err := dysmsapi.NewClientWithAccessKey(
"cn-hangzhou",
s.zoneId,
config.AccessKey,
config.AccessSecret)
if err != nil {
return nil, fmt.Errorf("failed to create client: %v", err)
return fmt.Errorf("failed to create client: %v", err)
}
return &AliYunSmsService{
config: config,
client: client,
}, nil
s.client = client
s.config = config
return nil
}
func (s *AliYunSmsService) SendVerifyCode(mobile string, code int) error {
if s.client == nil {
return fmt.Errorf("阿里云短信服务未初始化")
}
// 创建短信请求并设置参数
request := dysmsapi.CreateSendSmsRequest()
request.Scheme = "https"
request.Domain = s.config.Domain
request.Domain = s.domain
request.PhoneNumbers = mobile
request.SignName = s.config.Sign
request.TemplateCode = s.config.CodeTempId

View File

@@ -19,20 +19,21 @@ import (
)
type BaoSmsService struct {
config *types.SmsConfigBao
config types.SmsConfigBao
domain string
}
func NewSmsBaoSmsService(appConfig *types.AppConfig) *BaoSmsService {
config := appConfig.SMS.Bao
if config.Domain == "" { // use default domain
config.Domain = "api.smsbao.com"
logger.Infof("Using default domain for SMS-BAO: %s", config.Domain)
}
func NewBaoSmsService(sysConfig *types.SystemConfig) *BaoSmsService {
return &BaoSmsService{
config: &config,
config: sysConfig.SMS.Bao,
domain: "api.smsbao.com",
}
}
func (s *BaoSmsService) UpdateConfig(config types.SmsConfigBao) {
s.config = config
}
var errMsg = map[string]string{
"0": "短信发送成功",
"-1": "参数不全",
@@ -56,7 +57,7 @@ func (s *BaoSmsService) SendVerifyCode(mobile string, code int) error {
params.Set("m", mobile)
params.Set("c", content)
apiURL := fmt.Sprintf("https://%s/sms?%s", s.config.Domain, params.Encode())
apiURL := fmt.Sprintf("https://%s/sms?%s", s.domain, params.Encode())
response, err := http.Get(apiURL)
if err != nil {
return err

View File

@@ -7,8 +7,8 @@ package sms
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
const Ali = "ALI"
const Bao = "BAO"
const Ali = "aliyun"
const Bao = "bao"
type Service interface {
SendVerifyCode(mobile string, code int) error

View File

@@ -1,46 +0,0 @@
package sms
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core/types"
logger2 "geekai/logger"
"strings"
)
type ServiceManager struct {
handler Service
}
var logger = logger2.GetLogger()
func NewSendServiceManager(config *types.AppConfig) (*ServiceManager, error) {
active := Ali
if config.SMS.Active != "" {
active = strings.ToUpper(config.SMS.Active)
}
var handler Service
switch active {
case Ali:
client, err := NewAliYunSmsService(config)
if err != nil {
return nil, err
}
handler = client
break
case Bao:
handler = NewSmsBaoSmsService(config)
break
}
return &ServiceManager{handler: handler}, nil
}
func (m *ServiceManager) GetService() Service {
return m.handler
}

View File

@@ -0,0 +1,54 @@
package sms
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"geekai/core/types"
logger2 "geekai/logger"
)
type SmsManager struct {
aliyun *AliYunSmsService
bao *BaoSmsService
active string
}
var logger = logger2.GetLogger()
func NewSmsManager(sysConfig *types.SystemConfig, aliyun *AliYunSmsService, bao *BaoSmsService) (*SmsManager, error) {
return &SmsManager{
active: sysConfig.SMS.Active,
aliyun: aliyun,
bao: bao,
}, nil
}
func (m *SmsManager) GetService() Service {
switch m.active {
case Ali:
return m.aliyun
case Bao:
return m.bao
}
return nil
}
func (m *SmsManager) SetActive(active string) {
m.active = active
}
func (m *SmsManager) UpdateConfig(config types.SMSConfig) {
switch config.Active {
case Ali:
m.aliyun.UpdateConfig(config.Ali)
case Bao:
m.bao.UpdateConfig(config.Bao)
}
m.active = config.Active
}

View File

@@ -27,6 +27,10 @@ func NewSmtpService(appConfig *types.AppConfig) *SmtpService {
}
}
func (s *SmtpService) UpdateConfig(config *types.SmtpConfig) {
s.config = config
}
func (s *SmtpService) SendVerifyCode(to string, code int) error {
subject := fmt.Sprintf("%s 注册验证码", s.config.AppName)
body := fmt.Sprintf("【%s】您的验证码为 %d请不要告诉他人。如非本人操作请忽略此邮件。", s.config.AppName, code)

View File

@@ -112,6 +112,10 @@ type RespVo struct {
Message string `json:"message"`
Data string `json:"data"`
Channel string `json:"channel,omitempty"`
Error struct {
Message string `json:"message"`
Type string `json:"type"`
} `json:"error,omitempty"`
}
func (s *Service) Create(task types.SunoTask) (RespVo, error) {
@@ -126,7 +130,7 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
reqBody := map[string]any{
"task_id": task.RefTaskId,
"continue_clip_id": task.RefSongId,
"continue_at": task.ExtendSecs,
@@ -154,13 +158,14 @@ func (s *Service) Create(task types.SunoTask) (RespVo, error) {
}
body, _ := io.ReadAll(r.Body)
logger.Debugf("API response: %s", string(body))
err = json.Unmarshal(body, &res)
if err != nil {
return RespVo{}, fmt.Errorf("解析API数据失败%v, %s", err, string(body))
}
if res.Code != "success" {
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Message)
return RespVo{}, fmt.Errorf("API 返回失败:%s", res.Error.Message)
}
// update the last_use_at for api key
apiKey.LastUsedAt = time.Now().Unix()
@@ -225,7 +230,7 @@ func (s *Service) Upload(task types.SunoTask) (RespVo, error) {
return RespVo{}, errors.New("no available API KEY for Suno")
}
reqBody := map[string]interface{}{
reqBody := map[string]any{
"url": task.AudioURL,
}
@@ -330,7 +335,13 @@ func (s *Service) SyncTaskProgress() {
job.SongId = v.Id
job.Duration = int(v.Metadata.Duration)
job.Prompt = v.Metadata.Prompt
job.Tags = v.Metadata.Tags
// 修复 tags 字段过长导致插入数据库失败
if len(v.Metadata.Tags) > 255 {
job.Tags = v.Metadata.Tags[:255]
} else {
job.Tags = v.Metadata.Tags
}
job.ModelName = v.ModelName
job.RawData = utils.JsonEncode(v)
job.CoverURL = v.ImageLargeUrl

View File

@@ -1,5 +1,7 @@
package service
import logger2 "geekai/logger"
const FailTaskProgress = 101
const (
TaskStatusRunning = "RUNNING"
@@ -15,6 +17,8 @@ type NotifyMessage struct {
Type string `json:"type"`
}
var logger = logger2.GetLogger()
const TranslatePromptTemplate = "Translate the following painting prompt words into English keyword phrases. Without any explanation, directly output the keyword phrases separated by commas. The content to be translated is: [%s]"
const ImagePromptOptimizeTemplate = `

View File

@@ -0,0 +1,117 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"errors"
"fmt"
"geekai/core/types"
"geekai/utils"
"time"
"github.com/go-redis/redis/v8"
"github.com/imroc/req/v3"
)
type WxLoginService struct {
config types.WxLoginConfig
client *req.Client
redisClient *redis.Client
}
const loginStateKeyPrefix = "wx_login_state/"
type LoginStatus struct {
Status string `json:"status"`
OpenID string `json:"openid"`
Token string `json:"token"`
}
const (
LoginStatusPending = "pending"
LoginStatusSuccess = "success"
LoginStatusExpired = "expired" // 登录失效,需要重新登录
)
func NewWxLoginService(config types.WxLoginConfig, redisClient *redis.Client) *WxLoginService {
return &WxLoginService{
config: config,
client: req.C().SetTimeout(10 * time.Second),
redisClient: redisClient,
}
}
func (s *WxLoginService) UpdateConfig(config types.WxLoginConfig) {
s.config = config
}
func (s *WxLoginService) GetConfig() types.WxLoginConfig {
return s.config
}
func (s *WxLoginService) SetConfig(config types.WxLoginConfig) {
s.config = config
}
func (s *WxLoginService) GetLoginQrCodeUrl(state string) (string, error) {
if s.config.ApiKey == "" {
return "", errors.New("无效的 API Key")
}
url := fmt.Sprintf("%s/api/auth/wechat/login", types.GeekAPIURL)
var res struct {
Code types.BizCode `json:"code"`
Message string `json:"message"`
Data struct {
Ticket string `json:"ticket"`
Url string `json:"url"`
} `json:"data"`
}
r, err := s.client.R().
SetHeader("Authorization", s.config.ApiKey).
SetBody(map[string]string{
"notify_url": s.config.NotifyURL,
"state": state,
}).
SetSuccessResult(&res).Post(url)
if err != nil || r.IsErrorState() {
return "", fmt.Errorf("请求 API 失败:%v", err)
}
if res.Code != types.Success {
return "", fmt.Errorf("请求 API 失败:%s", res.Message)
}
status := LoginStatus{
Status: LoginStatusPending,
OpenID: "",
}
s.redisClient.Set(context.Background(), loginStateKeyPrefix+state, utils.JsonEncode(status), time.Hour)
return res.Data.Url, nil
}
func (s *WxLoginService) GetLoginStatus(state string) (*LoginStatus, error) {
result, err := s.redisClient.Get(context.Background(), loginStateKeyPrefix+state).Result()
if err != nil {
return nil, errors.New("登录失败")
}
var status LoginStatus
err = utils.JsonDecode(result, &status)
if err != nil {
return nil, errors.New("登录失败")
}
return &status, nil
}
func (s *WxLoginService) SetLoginStatus(state string, status LoginStatus) {
s.redisClient.Set(context.Background(), loginStateKeyPrefix+state, utils.JsonEncode(status), time.Hour)
}

View File

@@ -1,64 +0,0 @@
package service
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
// * Use of this source code is governed by a Apache-2.0 license
// * that can be found in the LICENSE file.
// * @Author yangjian102621@163.com
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import (
"context"
"geekai/core/types"
logger2 "geekai/logger"
"github.com/xxl-job/xxl-job-executor-go"
"gorm.io/gorm"
)
var logger = logger2.GetLogger()
type XXLJobExecutor struct {
executor xxl.Executor
db *gorm.DB
}
func NewXXLJobExecutor(config *types.AppConfig, db *gorm.DB) *XXLJobExecutor {
if !config.XXLConfig.Enabled {
logger.Info("XXL-JOB service is disabled")
return nil
}
exec := xxl.NewExecutor(
xxl.ServerAddr(config.XXLConfig.ServerAddr),
xxl.AccessToken(config.XXLConfig.AccessToken), //请求令牌(默认为空)
xxl.ExecutorIp(config.XXLConfig.ExecutorIp), //可自动获取
xxl.ExecutorPort(config.XXLConfig.ExecutorPort), //默认9999非必填
xxl.RegistryKey(config.XXLConfig.RegistryKey), //执行器名称
xxl.SetLogger(&customLogger{}), //自定义日志
)
exec.Init()
return &XXLJobExecutor{executor: exec, db: db}
}
func (e *XXLJobExecutor) Run() error {
e.executor.RegTask("ClearOrders", e.ClearOrders)
return e.executor.Run()
}
// ClearOrders 清理未支付的订单,如果没有抛出异常则表示执行成功
func (e *XXLJobExecutor) ClearOrders(cxt context.Context, param *xxl.RunReq) (msg string) {
logger.Info("执行清理未支付订单...")
return "success"
}
type customLogger struct{}
func (l *customLogger) Info(format string, a ...interface{}) {
logger.Debugf(format, a...)
}
func (l *customLogger) Error(format string, a ...interface{}) {
logger.Errorf(format, a...)
}