mirror of
				https://github.com/yangjian102621/geekai.git
				synced 2025-11-04 08:13:43 +08:00 
			
		
		
		
	Compare commits
	
		
			489 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 
						 | 
					5f820b9dc1 | ||
| 
						 | 
					4a99be2f15 | ||
| 
						 | 
					27c816cf3b | ||
| 
						 | 
					0d81776212 | ||
| 
						 | 
					cccab31c0f | ||
| 
						 | 
					4ddf3bf2bf | ||
| 
						 | 
					3d37a3d367 | ||
| 
						 | 
					73d8236697 | ||
| 
						 | 
					114d0088dc | ||
| 
						 | 
					43b6665370 | ||
| 
						 | 
					5fb9f84182 | ||
| 
						 | 
					e35c34ad9a | ||
| 
						 | 
					1a4d798f8b | ||
| 
						 | 
					afb91a7023 | ||
| 
						 | 
					dc4c1f7877 | ||
| 
						 | 
					bbc8fe2b40 | ||
| 
						 | 
					3c34e8e0e7 | ||
| 
						 | 
					57c932f07c | ||
| 
						 | 
					922202734a | ||
| 
						 | 
					8b3b0139b0 | ||
| 
						 | 
					31828a3336 | ||
| 
						 | 
					b270960a04 | ||
| 
						 | 
					5c4899df6e | ||
| 
						 | 
					9a797bb4a5 | ||
| 
						 | 
					b0c9ffc5a6 | ||
| 
						 | 
					f527cc5b98 | ||
| 
						 | 
					debe8dc209 | ||
| 
						 | 
					2f0215ac87 | ||
| 
						 | 
					dd5cc206e5 | ||
| 
						 | 
					142cd553a3 | ||
| 
						 | 
					657ecccee3 | ||
| 
						 | 
					1232c3cd9c | ||
| 
						 | 
					3ac04a3938 | ||
| 
						 | 
					b7abc42209 | ||
| 
						 | 
					a48179ce0e | ||
| 
						 | 
					e589f25a05 | ||
| 
						 | 
					cc1a3ce343 | ||
| 
						 | 
					7bb76d581c | ||
| 
						 | 
					0d733c0be0 | ||
| 
						 | 
					8b40ac5b5c | ||
| 
						 | 
					24479814e9 | ||
| 
						 | 
					99df028237 | ||
| 
						 | 
					b354b88876 | ||
| 
						 | 
					5e0be4d10e | ||
| 
						 | 
					468b48151f | ||
| 
						 | 
					fa5c036041 | ||
| 
						 | 
					0fdc588167 | ||
| 
						 | 
					2e023cb8dc | ||
| 
						 | 
					e933f32d9c | ||
| 
						 | 
					bd4b0c4d65 | ||
| 
						 | 
					0b2501c1d8 | ||
| 
						 | 
					9d28e62142 | ||
| 
						 | 
					c1d892069e | ||
| 
						 | 
					61b2dbc9f1 | ||
| 
						 | 
					be3245666e | ||
| 
						 | 
					dacdd6fe74 | ||
| 
						 | 
					6807f7e88a | ||
| 
						 | 
					087f5ab2d1 | ||
| 
						 | 
					47c5a0387b | ||
| 
						 | 
					f9da18ad52 | ||
| 
						 | 
					5c9025ca22 | ||
| 
						 | 
					d02cb573fd | ||
| 
						 | 
					caa538a1d0 | ||
| 
						 | 
					b584b4bfb6 | ||
| 
						 | 
					bda335212d | ||
| 
						 | 
					06f4cdc649 | ||
| 
						 | 
					336a7d5b56 | ||
| 
						 | 
					a0f464830f | ||
| 
						 | 
					9bf7fa4081 | ||
| 
						 | 
					96ead65774 | ||
| 
						 | 
					7ad41927aa | ||
| 
						 | 
					4ca9dfd9c0 | ||
| 
						 | 
					8a9f386d8f | ||
| 
						 | 
					adfee8bf58 | ||
| 
						 | 
					fbfa2a71a9 | ||
| 
						 | 
					9a1368ef17 | ||
| 
						 | 
					31b02b97d3 | ||
| 
						 | 
					42da38c5c3 | ||
| 
						 | 
					0a01b55713 | ||
| 
						 | 
					3b292c2a12 | ||
| 
						 | 
					db0ba0d9a0 | ||
| 
						 | 
					3a23ff6b42 | ||
| 
						 | 
					1e9c5adb0a | ||
| 
						 | 
					abab76ccc6 | ||
| 
						 | 
					6efd92806f | ||
| 
						 | 
					cfe333e89f | ||
| 
						 | 
					a7237fe62f | ||
| 
						 | 
					c3c454b7d7 | ||
| 
						 | 
					d4d708d44b | ||
| 
						 | 
					7f0b6a3a46 | ||
| 
						 | 
					c2a7c089d2 | ||
| 
						 | 
					df5bd4df60 | ||
| 
						 | 
					79b6010104 | ||
| 
						 | 
					97b0a98793 | ||
| 
						 | 
					5230f90540 | ||
| 
						 | 
					803db4e895 | ||
| 
						 | 
					7cee9f2ebb | ||
| 
						 | 
					8be9a21efd | ||
| 
						 | 
					6a3e26b566 | ||
| 
						 | 
					0355c37bef | ||
| 
						 | 
					9b7ee538c4 | ||
| 
						 | 
					d900a3d08e | ||
| 
						 | 
					cdf5b66729 | ||
| 
						 | 
					1cff4b63cd | ||
| 
						 | 
					da14309ef9 | ||
| 
						 | 
					fbb216fe3b | ||
| 
						 | 
					95efbd5659 | ||
| 
						 | 
					4596c1049c | ||
| 
						 | 
					b35d95f0c7 | ||
| 
						 | 
					01419df998 | ||
| 
						 | 
					a6c00c42fa | ||
| 
						 | 
					4cc9db7115 | ||
| 
						 | 
					4f1ed54059 | ||
| 
						 | 
					8227a73e35 | ||
| 
						 | 
					adfd8c1939 | ||
| 
						 | 
					8eed7ff534 | ||
| 
						 | 
					c79c4e74d0 | ||
| 
						 | 
					f1855fd0a1 | ||
| 
						 | 
					1f964c74e9 | ||
| 
						 | 
					4fb2c5803c | ||
| 
						 | 
					b5947545cb | ||
| 
						 | 
					342b76f666 | ||
| 
						 | 
					49b5906bc7 | ||
| 
						 | 
					3075bfb7fc | ||
| 
						 | 
					82e06fad33 | ||
| 
						 | 
					4a9028747b | ||
| 
						 | 
					4a8ff0ccf0 | ||
| 
						 | 
					99341f0484 | ||
| 
						 | 
					f58ac29ad0 | ||
| 
						 | 
					7060edb3e5 | ||
| 
						 | 
					41ae411f9b | ||
| 
						 | 
					79b7fee47c | ||
| 
						 | 
					0044bf10af | ||
| 
						 | 
					e9348d3611 | ||
| 
						 | 
					b9236e09a7 | ||
| 
						 | 
					09b38d5f42 | ||
| 
						 | 
					7bb539a06e | ||
| 
						 | 
					5cdada8265 | ||
| 
						 | 
					4147c217b1 | ||
| 
						 | 
					8dda639b23 | ||
| 
						 | 
					8487d2c9eb | ||
| 
						 | 
					c5e583b215 | ||
| 
						 | 
					549f618cff | ||
| 
						 | 
					e9a3510346 | ||
| 
						 | 
					30e6e963b3 | ||
| 
						 | 
					c72d963f45 | ||
| 
						 | 
					172d498618 | ||
| 
						 | 
					313993532e | ||
| 
						 | 
					e53db3582c | ||
| 
						 | 
					72c6bd3f77 | ||
| 
						 | 
					ca8b349df3 | ||
| 
						 | 
					1b206c3640 | ||
| 
						 | 
					c60276fc9f | ||
| 
						 | 
					d00a3167c0 | ||
| 
						 | 
					6b1cd8c30c | ||
| 
						 | 
					46f12dc9ad | ||
| 
						 | 
					a3e1d8ae21 | ||
| 
						 | 
					72a066b93e | ||
| 
						 | 
					0327a829ac | ||
| 
						 | 
					882e9b8819 | ||
| 
						 | 
					ef58cfadaa | ||
| 
						 | 
					bf958d6113 | ||
| 
						 | 
					71611273d7 | ||
| 
						 | 
					b27c654311 | ||
| 
						 | 
					90930ea9f9 | ||
| 
						 | 
					1ab2185ff1 | ||
| 
						 | 
					0f2f978d4c | ||
| 
						 | 
					f61963b0b0 | ||
| 
						 | 
					2aa413960d | ||
| 
						 | 
					aa4bbba5ec | ||
| 
						 | 
					eba61fea2d | ||
| 
						 | 
					34e3455128 | ||
| 
						 | 
					07dca3e739 | ||
| 
						 | 
					4cb4b145f9 | ||
| 
						 | 
					1ed417cb69 | ||
| 
						 | 
					6cf91a84ca | ||
| 
						 | 
					0b566980fc | ||
| 
						 | 
					f86176b342 | ||
| 
						 | 
					c700b32670 | ||
| 
						 | 
					22641b452a | ||
| 
						 | 
					d3fbb8c19e | ||
| 
						 | 
					e3bb69ff10 | ||
| 
						 | 
					770360c614 | ||
| 
						 | 
					f302a0478f | ||
| 
						 | 
					a88697b43a | ||
| 
						 | 
					cc6f140812 | ||
| 
						 | 
					424f2b3bdc | ||
| 
						 | 
					ec0c13a600 | ||
| 
						 | 
					a1f03bec4c | ||
| 
						 | 
					b5bd4a5e0e | ||
| 
						 | 
					7c2e49bfdb | ||
| 
						 | 
					f80fe6d041 | ||
| 
						 | 
					72f80a96bc | ||
| 
						 | 
					2de655a1cf | ||
| 
						 | 
					da2bd4a501 | ||
| 
						 | 
					e0aa62c40d | ||
| 
						 | 
					9d26a892d1 | ||
| 
						 | 
					4ece7f2847 | ||
| 
						 | 
					32368caf1b | ||
| 
						 | 
					e91f54e79e | ||
| 
						 | 
					bb8f4c57c4 | ||
| 
						 | 
					43bfac99b6 | ||
| 
						 | 
					be379b6d63 | ||
| 
						 | 
					17f3c9b840 | ||
| 
						 | 
					24de97fac2 | ||
| 
						 | 
					bf27b44fee | ||
| 
						 | 
					1802b4fe4d | ||
| 
						 | 
					241a5c7bc9 | ||
| 
						 | 
					557d547bf1 | ||
| 
						 | 
					2e7b75affb | ||
| 
						 | 
					bc21a1d443 | ||
| 
						 | 
					3fc9e10a24 | ||
| 
						 | 
					5fa1aa2060 | ||
| 
						 | 
					be8a0ec184 | ||
| 
						 | 
					b02e3aad95 | ||
| 
						 | 
					08eca511ad | ||
| 
						 | 
					c34e911596 | ||
| 
						 | 
					8a452c3072 | ||
| 
						 | 
					13bfb14107 | ||
| 
						 | 
					4188b0969e | ||
| 
						 | 
					0c27795a10 | ||
| 
						 | 
					d05693c5c1 | ||
| 
						 | 
					c0b2063b38 | ||
| 
						 | 
					4d183747b1 | ||
| 
						 | 
					08fe1b2f75 | ||
| 
						 | 
					db3e8a267e | ||
| 
						 | 
					8fc62682c4 | ||
| 
						 | 
					75031914a3 | ||
| 
						 | 
					a4c9fdd95a | ||
| 
						 | 
					6a9bfeb5aa | ||
| 
						 | 
					e654766f60 | ||
| 
						 | 
					0ef6955f96 | ||
| 
						 | 
					b4501557c9 | ||
| 
						 | 
					a2ed99e6cb | ||
| 
						 | 
					6bd6bb3885 | ||
| 
						 | 
					399cf65fc9 | ||
| 
						 | 
					24906a6df1 | ||
| 
						 | 
					d772bbebe6 | ||
| 
						 | 
					14988853a3 | ||
| 
						 | 
					7b3f16ac9f | ||
| 
						 | 
					82b2755c18 | ||
| 
						 | 
					ff4b267858 | ||
| 
						 | 
					a590d0497f | ||
| 
						 | 
					ac30d906f0 | ||
| 
						 | 
					5bc071e038 | ||
| 
						 | 
					88b956cf98 | ||
| 
						 | 
					f725cf4661 | ||
| 
						 | 
					057cc1e8a6 | ||
| 
						 | 
					de122735b8 | ||
| 
						 | 
					e87ede981c | ||
| 
						 | 
					606fb498e1 | ||
| 
						 | 
					a0c06e40a4 | ||
| 
						 | 
					aba8f57279 | ||
| 
						 | 
					960286a350 | ||
| 
						 | 
					8c93fa51f6 | ||
| 
						 | 
					cb0e7d64ff | ||
| 
						 | 
					8e7413da97 | ||
| 
						 | 
					a36f14eb94 | ||
| 
						 | 
					f2f9f6e488 | ||
| 
						 | 
					85068b8ca2 | ||
| 
						 | 
					f2cfcfeefc | ||
| 
						 | 
					755273a898 | ||
| 
						 | 
					d4a24a0f1d | ||
| 
						 | 
					92281fcbb7 | ||
| 
						 | 
					636db4afcc | ||
| 
						 | 
					ba25b8755e | ||
| 
						 | 
					6399d13a49 | ||
| 
						 | 
					06fa54fd25 | ||
| 
						 | 
					a335b965d0 | ||
| 
						 | 
					725adaa7d0 | ||
| 
						 | 
					7e7e81e974 | ||
| 
						 | 
					8cfe6bfc17 | ||
| 
						 | 
					33de83f2ac | ||
| 
						 | 
					3f856afec8 | ||
| 
						 | 
					4e4dc4cb73 | ||
| 
						 | 
					02a9c422fe | ||
| 
						 | 
					ca69341024 | ||
| 
						 | 
					169bf069ce | ||
| 
						 | 
					1bee0ab04d | ||
| 
						 | 
					440d91dd0e | ||
| 
						 | 
					8168e246a8 | ||
| 
						 | 
					2ef07574ae | ||
| 
						 | 
					37392f2bb2 | ||
| 
						 | 
					a80cd3848e | ||
| 
						 | 
					db6ed84451 | ||
| 
						 | 
					4463cc5963 | ||
| 
						 | 
					d316158fe2 | ||
| 
						 | 
					e02a8d7586 | ||
| 
						 | 
					9988dff885 | ||
| 
						 | 
					35ef5674ff | ||
| 
						 | 
					976da45bce | ||
| 
						 | 
					c83ac48bd2 | ||
| 
						 | 
					3d159a833e | ||
| 
						 | 
					4b09878bdd | ||
| 
						 | 
					b0162e6a92 | ||
| 
						 | 
					8ab15e5dc4 | ||
| 
						 | 
					d2ac807252 | ||
| 
						 | 
					0af01f6f1f | ||
| 
						 | 
					013b319fab | ||
| 
						 | 
					2899ba5949 | ||
| 
						 | 
					a558b7e104 | ||
| 
						 | 
					7a833e2233 | ||
| 
						 | 
					bf65746d00 | ||
| 
						 | 
					f08a7862de | ||
| 
						 | 
					023a2c2f09 | ||
| 
						 | 
					1bcd0f4c1a | ||
| 
						 | 
					a0f3bc8ccb | ||
| 
						 | 
					dea72738c1 | ||
| 
						 | 
					a1d1fe7763 | ||
| 
						 | 
					a39ed9764c | ||
| 
						 | 
					aaa5ba99aa | ||
| 
						 | 
					2113508b6d | ||
| 
						 | 
					7fe4212684 | ||
| 
						 | 
					8bdda64794 | ||
| 
						 | 
					ec08c24dca | ||
| 
						 | 
					a992a5b3b3 | ||
| 
						 | 
					0f05970141 | ||
| 
						 | 
					e5e762efcd | ||
| 
						 | 
					b3d0c1ef9c | ||
| 
						 | 
					397078f7ff | ||
| 
						 | 
					3ad8065e20 | ||
| 
						 | 
					66c7717f04 | ||
| 
						 | 
					412f8ecc6c | ||
| 
						 | 
					51dcf642b3 | ||
| 
						 | 
					bfeea555b2 | ||
| 
						 | 
					479f94c372 | ||
| 
						 | 
					0140713e86 | ||
| 
						 | 
					15b2ec9721 | ||
| 
						 | 
					c9cd082855 | ||
| 
						 | 
					d7c002890c | ||
| 
						 | 
					348dd22279 | ||
| 
						 | 
					3e99b4cbf6 | ||
| 
						 | 
					6968da3ac7 | ||
| 
						 | 
					bf1c1b84c3 | ||
| 
						 | 
					c70314d930 | ||
| 
						 | 
					9104ca8e49 | ||
| 
						 | 
					2af33b3630 | ||
| 
						 | 
					654e795545 | ||
| 
						 | 
					c62ba2451e | ||
| 
						 | 
					d72d1b8a99 | ||
| 
						 | 
					b939d6016b | ||
| 
						 | 
					36a2626ccc | ||
| 
						 | 
					bd057a4cc9 | ||
| 
						 | 
					dc24a8c781 | ||
| 
						 | 
					59fa21779b | ||
| 
						 | 
					a140671aad | ||
| 
						 | 
					5fe8990fb4 | ||
| 
						 | 
					12799b7159 | ||
| 
						 | 
					9929746b1d | ||
| 
						 | 
					d70035ff0c | ||
| 
						 | 
					eec90274d8 | ||
| 
						 | 
					e8fff55c42 | ||
| 
						 | 
					3cf3cdd705 | ||
| 
						 | 
					9801fce659 | ||
| 
						 | 
					4c1f51110b | ||
| 
						 | 
					913d538587 | ||
| 
						 | 
					9e704365fc | ||
| 
						 | 
					485bdbc56a | ||
| 
						 | 
					7000168fd4 | ||
| 
						 | 
					5694f97a6b | ||
| 
						 | 
					b677d3fac7 | ||
| 
						 | 
					dc6719cf54 | ||
| 
						 | 
					7de5b55091 | ||
| 
						 | 
					76c5101092 | ||
| 
						 | 
					2f8d2f4854 | ||
| 
						 | 
					b1ee34ba0c | ||
| 
						 | 
					069ad6a09a | ||
| 
						 | 
					bf1403c818 | ||
| 
						 | 
					bcc622a24d | ||
| 
						 | 
					a06a81a415 | ||
| 
						 | 
					d1950acd01 | ||
| 
						 | 
					039b70eed2 | ||
| 
						 | 
					d8e4308b1b | ||
| 
						 | 
					434fbb3463 | ||
| 
						 | 
					de3eb8969c | ||
| 
						 | 
					fbd6eac877 | ||
| 
						 | 
					1fecab177b | ||
| 
						 | 
					b1b385c455 | ||
| 
						 | 
					3c6e86d04b | ||
| 
						 | 
					3d2035d08a | ||
| 
						 | 
					da86f916d8 | ||
| 
						 | 
					e7a07f7e92 | ||
| 
						 | 
					b01e6387fc | ||
| 
						 | 
					d86aca0f5d | ||
| 
						 | 
					09414fe36a | ||
| 
						 | 
					df0e7508db | ||
| 
						 | 
					92b1f01118 | ||
| 
						 | 
					8fb8bd932b | ||
| 
						 | 
					3f74b94784 | ||
| 
						 | 
					e9467341fa | ||
| 
						 | 
					131e051ddc | ||
| 
						 | 
					f626fe3166 | ||
| 
						 | 
					6bc57b6132 | ||
| 
						 | 
					d972e97c88 | ||
| 
						 | 
					3991f4daec | ||
| 
						 | 
					f6b567d6fc | ||
| 
						 | 
					8addba8203 | ||
| 
						 | 
					3ab930a107 | ||
| 
						 | 
					de512a5ea2 | ||
| 
						 | 
					113cfae2dc | ||
| 
						 | 
					33aebf9cb5 | ||
| 
						 | 
					6e58ddf681 | ||
| 
						 | 
					cae5c049e4 | ||
| 
						 | 
					ff76e4bd89 | ||
| 
						 | 
					a0a506a3c4 | ||
| 
						 | 
					aa5a4a9977 | ||
| 
						 | 
					abf4f061c1 | ||
| 
						 | 
					245cd3ee1a | ||
| 
						 | 
					45cb29d9a0 | ||
| 
						 | 
					d974b1ff0e | ||
| 
						 | 
					56269170cb | ||
| 
						 | 
					4290c4ca22 | ||
| 
						 | 
					7f7c8e831e | ||
| 
						 | 
					8f057ca9d1 | ||
| 
						 | 
					4a56621ec3 | ||
| 
						 | 
					a398e7a550 | ||
| 
						 | 
					96816c12ca | ||
| 
						 | 
					9984926f69 | ||
| 
						 | 
					a2a6081027 | ||
| 
						 | 
					5a10ed37a7 | ||
| 
						 | 
					1a9dd9de0b | ||
| 
						 | 
					0dae5bef71 | ||
| 
						 | 
					b4413ed726 | ||
| 
						 | 
					5e1fe88b8b | ||
| 
						 | 
					91ed41b536 | ||
| 
						 | 
					024c0032eb | ||
| 
						 | 
					4a9f7e3bce | ||
| 
						 | 
					cf4dcc34ec | ||
| 
						 | 
					4d612c15af | ||
| 
						 | 
					8aec87cc02 | ||
| 
						 | 
					442e411cde | ||
| 
						 | 
					acec0194de | ||
| 
						 | 
					8557f5b94a | ||
| 
						 | 
					babef8baae | ||
| 
						 | 
					efd4ab46f5 | ||
| 
						 | 
					ae8239e5de | ||
| 
						 | 
					f0994ba457 | ||
| 
						 | 
					dae91ed243 | ||
| 
						 | 
					de42a428e6 | ||
| 
						 | 
					63c7041e1f | ||
| 
						 | 
					b1263ddc69 | ||
| 
						 | 
					7e50e17aaf | ||
| 
						 | 
					a7265c4251 | ||
| 
						 | 
					6f39f639bd | ||
| 
						 | 
					a7db123437 | ||
| 
						 | 
					241c714a8b | ||
| 
						 | 
					67ac3cfe32 | ||
| 
						 | 
					c926e0afcc | ||
| 
						 | 
					5bc07e6d57 | ||
| 
						 | 
					c3666a9a71 | ||
| 
						 | 
					23b5ffa97d | ||
| 
						 | 
					a2c7a75705 | ||
| 
						 | 
					d68f2ef12c | ||
| 
						 | 
					67d30353f0 | ||
| 
						 | 
					4813163eac | ||
| 
						 | 
					5c5210625e | ||
| 
						 | 
					a4a1eec30b | ||
| 
						 | 
					d35164506a | ||
| 
						 | 
					1ed08f01ea | ||
| 
						 | 
					eca07ab830 | ||
| 
						 | 
					3512715704 | ||
| 
						 | 
					6d07881141 | ||
| 
						 | 
					251fe626f2 | ||
| 
						 | 
					5fee3a9288 | ||
| 
						 | 
					9b68d8101e | ||
| 
						 | 
					cfe6f27d48 | ||
| 
						 | 
					b314dd0900 | ||
| 
						 | 
					950fab6374 | ||
| 
						 | 
					9d1f5c42ce | ||
| 
						 | 
					a84046390b | ||
| 
						 | 
					aa29323a8a | ||
| 
						 | 
					d5617b7c3a | ||
| 
						 | 
					1ef60a9e5e | ||
| 
						 | 
					fb6e395ad8 | ||
| 
						 | 
					d9216060bc | ||
| 
						 | 
					bcaa9a92e5 | ||
| 
						 | 
					576adc9036 | ||
| 
						 | 
					00de18be9a | ||
| 
						 | 
					c61d32816a | ||
| 
						 | 
					f3fbb0b89c | ||
| 
						 | 
					e311a39632 | ||
| 
						 | 
					51407abe44 | ||
| 
						 | 
					8a470b1038 | ||
| 
						 | 
					baddabaa16 | ||
| 
						 | 
					427b434ce3 | ||
| 
						 | 
					5f921965e6 | ||
| 
						 | 
					1e705c8ed5 | ||
| 
						 | 
					c584b82ddb | ||
| 
						 | 
					72418ce4d7 | 
							
								
								
									
										6
									
								
								.dockerignore
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								.dockerignore
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,6 @@
 | 
			
		||||
deploy
 | 
			
		||||
docs
 | 
			
		||||
api/static 
 | 
			
		||||
web/node_modules
 | 
			
		||||
desktop
 | 
			
		||||
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/ISSUE_TEMPLATE/1.bug.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ISSUE_TEMPLATE/1.bug.yml
									
									
									
									
										vendored
									
									
								
							@@ -1,5 +1,5 @@
 | 
			
		||||
name: Bug 报告 🐛
 | 
			
		||||
description: 为 chatgpt-plus 提交错误报告
 | 
			
		||||
description: 为 geekai 提交错误报告
 | 
			
		||||
labels: ['Bug']
 | 
			
		||||
body:
 | 
			
		||||
  - type: checkboxes
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										2
									
								
								.github/ISSUE_TEMPLATE/2.feature.yml
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										2
									
								
								.github/ISSUE_TEMPLATE/2.feature.yml
									
									
									
									
										vendored
									
									
								
							@@ -1,5 +1,5 @@
 | 
			
		||||
name: 功能优化 🚀
 | 
			
		||||
description: 为 chatgpt-plus 提交优化建议
 | 
			
		||||
description: 为 geekai 提交优化建议
 | 
			
		||||
labels: ['feature']
 | 
			
		||||
body:
 | 
			
		||||
  - type: checkboxes
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										220
									
								
								CHANGELOG.md
									
									
									
									
									
								
							
							
						
						
									
										220
									
								
								CHANGELOG.md
									
									
									
									
									
								
							@@ -1,11 +1,226 @@
 | 
			
		||||
# 更新日志
 | 
			
		||||
 | 
			
		||||
## v4.0.6
 | 
			
		||||
 | 
			
		||||
* Bug修复:修复PC端画廊页面的瀑布流组件样式错乱问题
 | 
			
		||||
* 功能新增:给思维导图增加 ToolBar,实现思维导图的放大缩小和定位
 | 
			
		||||
* Bug修复:修复思维导图不扣费的Bug
 | 
			
		||||
* Bug修复:修复管理后台角色删除失败的Bug
 | 
			
		||||
* Bug修复:兼容最新版秋叶SD懒人包的 SD API,新增 scheduler 参数
 | 
			
		||||
* 功能优化:支持在管理后台配置 AI 绘图相关配置,包括 SD, MJ-PLUS, MJ-PROXY
 | 
			
		||||
* Bug修复:修复注册用户提示注册人数达到上限的 Bug
 | 
			
		||||
* 功能优化:将MJ,SD,Dall绘画页面的任务列表全改成瀑布流组件
 | 
			
		||||
 | 
			
		||||
## v4.0.5
 | 
			
		||||
 | 
			
		||||
* 功能优化:已授权系统在后台显示授权信息
 | 
			
		||||
* 功能优化:使用思维链提示词生成思维导图,确保生成的思维导图不会出现格式错误
 | 
			
		||||
* 功能优化:优化首页登录注册页面的 UI
 | 
			
		||||
* BUG修复:修复License验证的逻辑漏洞
 | 
			
		||||
* Bug修复:后台添加用户的时候密码规则限制跟前台注册保持一致
 | 
			
		||||
* 功能新增:管理后台支持切换主题,支持 light 和 dark 两种主题
 | 
			
		||||
* 功能新增:移动端新增 DALL-E 绘画功能
 | 
			
		||||
* 功能新增:新增移动端首页功能,移动端支持 light 和 dark 两种主题
 | 
			
		||||
* 功能新增:移动支持免登录预览功能
 | 
			
		||||
* Bug修复:解决在同一个浏览器开启多个对话时候对话内容会相互乱串的问题
 | 
			
		||||
* Bug修复:修复部分中转 API 模型会出现第一输出的字符被淹没的Bug
 | 
			
		||||
 | 
			
		||||
## v4.0.4
 | 
			
		||||
 | 
			
		||||
* Bug修复:修复统一千问第二句不回复的问题
 | 
			
		||||
* 功能优化:MJ 和 SD 任务正在执行时不更新已完成任务列表,加快页面渲染速度
 | 
			
		||||
* 功能新增:Dalle AI 绘画功能实现
 | 
			
		||||
* Bug修复:修复思维导图格式乱码问题
 | 
			
		||||
* 功能优化:支持使用 TLS 邮件协议,解决国内服务器无法使用 25 号端口发送邮件的问题
 | 
			
		||||
* 功能新增:支持从应用列表直接和某个应用对话
 | 
			
		||||
* 功能优化:优化算力日志的页面和首页的UI
 | 
			
		||||
* 功能新增:支持思维导图导出 PNG 图片下载
 | 
			
		||||
 | 
			
		||||
## v4.0.3
 | 
			
		||||
 | 
			
		||||
* 功能新增:允许为角色应用绑定模型,如指定某个角色只能使用某个模型
 | 
			
		||||
* Bug修复:兼容 gpt-4-turbo-2024-04-09 模型的函数调用 Bug
 | 
			
		||||
* Bug修复:修复MidJourney在任务超时后出现后面的任务覆盖前面任务的问题
 | 
			
		||||
* 功能新增:支持上传图片和视觉模型
 | 
			
		||||
* 功能优化:优化聊天页面的复制代码按钮样式乱码
 | 
			
		||||
* 功能新增:增加思维导图功能,支持选择不同的对话模型来生成思维导图
 | 
			
		||||
* 功能新增:支持为角色绑定对话模型,比如绑定某个角色只能用GPT3.5或者 GPT4
 | 
			
		||||
* 功能新增:支持为模型绑定 API KEY,比如为 GPT3.5 模型绑定免费的 API KEY 给用户免费使用来引流不至于消耗你的收费 KEY。
 | 
			
		||||
* 功能新增:支持管理后台 Logo 修改
 | 
			
		||||
 | 
			
		||||
## 4.0.2
 | 
			
		||||
 | 
			
		||||
* 功能新增:支持前端菜单可以配置
 | 
			
		||||
* 功能优化:在登录和注册界面标题显示软件版本号
 | 
			
		||||
* 功能优化:MJ 绘画支持 --sref 和 --cref 图片一致性参数
 | 
			
		||||
* 功能优化:使用 leveldb 解决 SD 绘图进度图片预览问题
 | 
			
		||||
* Bug修复:解决因为图片上传使用相对路径而导致融图失败的问题。
 | 
			
		||||
* 功能新增:手机端支持 Stable-Diffusion 绘画
 | 
			
		||||
* 功能新增:管理后台登录页面增加行为验证码,防止爆破
 | 
			
		||||
 | 
			
		||||
## v4.0.1
 | 
			
		||||
 | 
			
		||||
* 功能重构:重构 Stable-Diffusion 绘画实现,使用 SDAPI 替换之前的 websocket 接口,SDAPI 兼容各种 stable-diffusion
 | 
			
		||||
  发行版,稳定性更强一些
 | 
			
		||||
* 功能优化:使用 [midjouney-proxy](https://github.com/novicezk/midjourney-proxy) 项目替换内置的原生 MidJourney API,兼容
 | 
			
		||||
  MJ-Plus 中转
 | 
			
		||||
* 功能新增:用户算力消费日志增加统计功能,统计一段时间内用户消费的算力
 | 
			
		||||
* Bug修复:修复 iphone 手机无法通过图形验证码的Bug,使用滑动验证码替换
 | 
			
		||||
* Bug修复:修复手机端 MidJourney 绘画页面滚动条无法滚动的Bug
 | 
			
		||||
 | 
			
		||||
## v4.0.0
 | 
			
		||||
 | 
			
		||||
非兼容版本,重大重构,引入算力概念,将系统中所有的能力(AI对话,MJ绘画,SD绘画,DALL绘画)全部使用算力来兑换。
 | 
			
		||||
只要你的算力值余额不为0,你就可以进行任何操作。比如一次 GPT3.5 对话消耗1个单位算力,一次 GPT4 对话消耗10个算力。一次 MJ
 | 
			
		||||
对话消耗15个算力...
 | 
			
		||||
 | 
			
		||||
* 功能重构:重构整体系统,全部采用算力来进行结算
 | 
			
		||||
* 功能优化:SD 绘画页面采用 websocket 替换 http 轮询机制,节省带宽
 | 
			
		||||
* 功能优化:移动端聊天页面图片支持预览和放大功能
 | 
			
		||||
* 功能优化:MJ 和 SD 页面数据分页加载,解决一次性加载太多数据导致页面卡顿的问题
 | 
			
		||||
* 功能优化:**PC端不登录也可以预览功能,只有在发起操作的时候才需要登录**
 | 
			
		||||
* 功能优化:控制台订单管理页面显示未支付订单,并提供订单删除功能
 | 
			
		||||
* 功能新增:支持H5支付
 | 
			
		||||
* 功能优化:支持数学公式的识别和美化输出
 | 
			
		||||
* 功能新增:新增算力消费日志功能
 | 
			
		||||
* 功能优化:整合 XXL-JOB 实现订单清理,每日算力派发,VIP 算力重置等任务
 | 
			
		||||
* 功能新增:管理后台新增7日内新增用户和新增订单统计
 | 
			
		||||
 | 
			
		||||
## v3.2.7
 | 
			
		||||
 | 
			
		||||
* 功能重构:采用 Vant 重构移动页面,新增 MidJourney 功能
 | 
			
		||||
* 功能优化:优化 PC 端 MidJourney 页面布局,新增融图和换脸功能
 | 
			
		||||
* Bug修复:修复 issue [
 | 
			
		||||
  管理界面操作用户存在的两个问题](https://github.com/yangjian102621/chatgpt-plus/issues/117#issuecomment-1909201532)
 | 
			
		||||
* 功能优化:在对话和聊天记录表中新增冗余字段 model,存储对话模型
 | 
			
		||||
* Bug修复:IPhone 手机验证码触摸事件坐标错位 [issue 144](https://github.com/yangjian102621/chatgpt-plus/issues/144)
 | 
			
		||||
* Bug修复:重新生成按钮功能失效问题
 | 
			
		||||
* Bug修复:对话输入HTML标签不显示的问题
 | 
			
		||||
* 功能优化:gpt-4-all/gpts/midjourney-plus 支持第三方平台的 API KEY
 | 
			
		||||
* 功能新增:新增删除文件功能
 | 
			
		||||
* Bug修复:解决 MJ-Plus discord 图片下载失败问题,使用第三方平台中转地址下载
 | 
			
		||||
* 功能新增:后台管理新怎对话查看和检索功能
 | 
			
		||||
 | 
			
		||||
## v3.2.6
 | 
			
		||||
 | 
			
		||||
* 功能优化:恢复关闭注册系统配置项,管理员可以在后台关闭用户注册,只允许内部添加账号
 | 
			
		||||
* 功能优化:兼用旧版本微信收款消息解析
 | 
			
		||||
* 功能优化:优化订单扫码支付状态轮询功能,当关闭二维码时取消轮询,节约网络资源
 | 
			
		||||
* 功能新增:新增图片发布功能,画廊只显示用户已发布的图片
 | 
			
		||||
* 功能新增:后台新增配置微信客服二维码,可以上传自己的微信客服二维码
 | 
			
		||||
* 功能新增:新增网站公告,可以在管理后台自定义配置
 | 
			
		||||
* 功能新增:新增阿里通义千问大模型支持
 | 
			
		||||
* Bug修复:修复 MJ 放大任务失败时候 img_call 会增加的 Bug
 | 
			
		||||
* 功能优化:新增虎皮椒和PayJS订单状态校验功能,增加安全性
 | 
			
		||||
* Bug修复:修复微信转账交易 ID 提取失败 Bug
 | 
			
		||||
* 功能优化:给所有的 websocket 连接加上心跳,解决 "close 1006 (abnormal closure): unexpected EOF" Bug
 | 
			
		||||
* 功能新增:新增短信宝短信平台发送平台集成
 | 
			
		||||
 | 
			
		||||
## v3.2.5
 | 
			
		||||
 | 
			
		||||
* 功能新增:**重磅更新!!!** 新增 MidJourney-Plus API 支持,一秒配置,开箱即用,高效稳定。
 | 
			
		||||
* 功能新增:**重磅更新!!!** 新增 GPT4-ALL 和 GPTs 模型支持,你只需花几块钱,可以丝滑享受 ChatGPT-Plus 会员的所有功能,无需再订阅
 | 
			
		||||
  Plus 账号了!!!
 | 
			
		||||
* 功能优化:增强 markdown 图片和引用块解析。
 | 
			
		||||
* 功能新增:新增用户文件管理,目前一支持上传文件跟 GPT 进行多态对话。
 | 
			
		||||
* 功能优化:function call 兼用中转 API。
 | 
			
		||||
* Bug修复:修复部分已知的 Bug。
 | 
			
		||||
 | 
			
		||||
## v3.2.4.1
 | 
			
		||||
 | 
			
		||||
* 功能新增:新增 PayJs 支付通道
 | 
			
		||||
* Bug修复:紧急修复后台添加用户失败问题
 | 
			
		||||
* Bug修复:紧急修复使用中转 API-KEY 无法绘图的问题
 | 
			
		||||
* Bug修复:允许用户关闭手机和邮箱注册通道,移除验证码依赖
 | 
			
		||||
 | 
			
		||||
## v3.2.4
 | 
			
		||||
 | 
			
		||||
* 功能新增:重磅更新,支持邮箱注册
 | 
			
		||||
* 功能优化:优化函数调用授权
 | 
			
		||||
* 功能优化:给用户表新增 nickname 字段
 | 
			
		||||
* 功能优化:管理后台给聊天角色增加启用/禁用开关
 | 
			
		||||
* Bug修复:SD绘画出现重复扣减绘图次数
 | 
			
		||||
* 功能优化:优化聊天对话导出样式,适应移动端
 | 
			
		||||
* 功能新增:众筹核销可以选择兑换对话还是绘图的额度
 | 
			
		||||
* Bug修复:修复[从历史记录获取reply有并发风险 #92](https://github.com/yangjian102621/chatgpt-plus/issues/92)
 | 
			
		||||
* Bug修复:修复 MidJourney 绘图任务调度Bug,为 task_id 建议唯一索引
 | 
			
		||||
* 功能重构:重构了 API KEY模块,支持为每个 API KEY 都设置不同的 API 地址,并可以单独开启是否使用代理。
 | 
			
		||||
 | 
			
		||||
## v3.2.3
 | 
			
		||||
 | 
			
		||||
* 功能重构:重构函数工具模块,设计成可以后台动态管理函数。支持添加自定义函数实现
 | 
			
		||||
* 功能新增:为充值产品数据表添加 img_calls 字段,支持充值绘图次数
 | 
			
		||||
* Bug修复:修复 [MJ 机器人空指针异常的 Bug](https://github.com/yangjian102621/chatgpt-plus/issues/73)
 | 
			
		||||
* Bug修复:确保相同 Prompt 的绘图任务的 Upscale 和 Variation 任务调度给相同的频道
 | 
			
		||||
* 功能新增:新增删除绘图任何和图片功能
 | 
			
		||||
* Bug修复:修复虎皮椒支付二维码重复扫码时报错问题
 | 
			
		||||
* 功能优化:自动将 AI 绘画中的中文提示词翻译成英文
 | 
			
		||||
* 功能优化:优化AI绘画的大图压缩算法,新增图片缓存
 | 
			
		||||
* 功能优化:支持为 MJ 绘图 API 增加反代功能,提高图片的加载速度,大大降低绘图任务的失败率
 | 
			
		||||
* Bug修复:修复[Azure Api 更换api-version参数后请求失败的问题](https://github.com/yangjian102621/chatgpt-plus/pull/71)
 | 
			
		||||
* Bug修复:修复科大讯飞 V1.5 API 请求失败的问题
 | 
			
		||||
* Bug修复:绘图失败后,自动恢复用户的剩余绘图次数
 | 
			
		||||
* 功能新增:为移动端新增 SD 绘图功能,分享功能
 | 
			
		||||
 | 
			
		||||
## v3.2.2
 | 
			
		||||
 | 
			
		||||
* 功能重构:重构 MidJourney 和 Stable-Diffusion 绘图模块,支持使用多组配置创建池子提供绘画服务
 | 
			
		||||
* 功能新增:AI绘画页面增加翻译和重写提示词功能
 | 
			
		||||
* 功能优化:OSS上传组件支持在 Bucket 下设置二级目录
 | 
			
		||||
* Bug修复:修复阿里云 OSS 访问路径错误
 | 
			
		||||
* 功能优化:在 AI 绘图页面使用 HTTP 轮询替换 Websocket
 | 
			
		||||
 | 
			
		||||
## v3.2.1
 | 
			
		||||
 | 
			
		||||
* 功能优化:切换角色和模型的时候自动创建新的对话
 | 
			
		||||
* Bug修复:修复文件上传失败No such file bug
 | 
			
		||||
* 功能新增:MidJourney 绘画页面新增提示词翻译功能,新增多个绘画参数
 | 
			
		||||
* Bug修复:[PC端对话在刷新后异常](https://github.com/yangjian102621/chatgpt-plus/issues/59)
 | 
			
		||||
* 功能新增:增加 arm64 架构打包脚本
 | 
			
		||||
* 功能新增:支持 dall-e3 绘图的 API 地址自定义配置
 | 
			
		||||
* 功能新增:新增虎皮椒支付功能接入,支持微信和支付宝通道
 | 
			
		||||
 | 
			
		||||
## v3.2.0
 | 
			
		||||
 | 
			
		||||
* 功能新增:新增邀请注册功能
 | 
			
		||||
* 功能优化:增加中间件自动对HTTP请求的参数去掉首尾空格
 | 
			
		||||
* 功能优化:增加中间件自动为大图片生成缩略图
 | 
			
		||||
* 功能优化:MidJourney 页面图片加载优化,实现图片预览懒加载
 | 
			
		||||
* 功能新增:新增 DALL-E-3 绘画支持,并作为对话页面默认绘画插件
 | 
			
		||||
* Bug修复:修复阿里云 OSS 域名设置不起做用的bug
 | 
			
		||||
* Bug修复:修复MidJourney绘图失败后重复添加到队列的问题
 | 
			
		||||
 | 
			
		||||
## v3.1.9
 | 
			
		||||
 | 
			
		||||
* 功能新增:增加讯飞星火大模型 v3.0 支持
 | 
			
		||||
* 功能新增:新增找回密码功能
 | 
			
		||||
* 功能新增:支持 Markdown 代码复制功能
 | 
			
		||||
* Bug修复: xxl-job 任务调度失败的 Bug
 | 
			
		||||
* 功能优化:优化前端页面菜单图标,使用自定义图标替换 icon-font
 | 
			
		||||
* Bug修复:Stable-Diffusion 绘画成功之后没有扣减用户画图次数
 | 
			
		||||
* 功能优化:优化会员充值页面 ItemList 组件
 | 
			
		||||
* 功能优化:给首页 Logo 增加链接
 | 
			
		||||
* Bug修复:[新建会话时,提示"请输入合法的手机号" ](https://github.com/yangjian102621/chatgpt-plus/issues/51)
 | 
			
		||||
* Bug修复:聊天上下文失效问题
 | 
			
		||||
* 功能优化:关闭注册时显示联系管理员二维码
 | 
			
		||||
* 功能优化:移除 leveldb 依赖,使用 redis 替换相应的功能
 | 
			
		||||
* Bug修复:后台启用用户 VIP 不生效问题
 | 
			
		||||
* 功能优化:充值支付页面的支付说明文字可以后台配置
 | 
			
		||||
* Bug修复:ChatGLM,百度文心,科大讯飞模型输出代码不换行问题
 | 
			
		||||
 | 
			
		||||
## v3.1.8
 | 
			
		||||
 | 
			
		||||
1. 功能新增:新增会员套餐充值,点卡充值,订单系统,集成支付宝支付通道
 | 
			
		||||
2. Bug修复:修复 MidJourney API 参数版本更新导致调用失败的 Bug
 | 
			
		||||
3. 功能优化:将聊天报错信息定义为统一常量,方便修改
 | 
			
		||||
2. Bug修复:修复 MidJourney API 参数版本更新导致调用失败问题
 | 
			
		||||
3. Bug修复:修复 Stable Diffusion 调用后没有更新绘图调用次数问题
 | 
			
		||||
4. Bug修复:修复七牛云上传报错 expired token
 | 
			
		||||
5. Bug修复:修复高权重模型导致的对话次数为负数的漏洞
 | 
			
		||||
6. 功能优化:将聊天报错信息定义为统一常量,方便修改
 | 
			
		||||
7. 功能优化:优化 markdown 表格显示样式,覆写 Element-Plus 表格样式
 | 
			
		||||
8. 功能优化:增加倒数计时组件,定期自动清理未支付的订单
 | 
			
		||||
 | 
			
		||||
## v3.1.7
 | 
			
		||||
 | 
			
		||||
1. 功能新增:支持文心4.0 AI 模型
 | 
			
		||||
2. 功能新增:可以在管理后台为用户绑定指定的 AI 模型,如只给某个用户使用 GPT-4 模型
 | 
			
		||||
3. 功能新增:模型新增权重字段,不同的模型每次调用耗费的点数可以设置不同,比如GPT4是GPT3.5的10倍
 | 
			
		||||
@@ -13,6 +228,7 @@
 | 
			
		||||
5. 功能优化:优化 MidJourney 专业绘画页面图片预览样式
 | 
			
		||||
 | 
			
		||||
## v3.1.6
 | 
			
		||||
 | 
			
		||||
1. 功能新增:新增AI 绘画照片墙功能页面,供用户查看所有的 AI 绘画作品
 | 
			
		||||
2. 功能新增:新增 AI 角色应用功能页面,用户可以添加自己感兴趣的应用
 | 
			
		||||
3. 功能优化:优化瀑布流组件的页面布局
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										214
									
								
								LICENSE
									
									
									
									
									
								
							
							
						
						
									
										214
									
								
								LICENSE
									
									
									
									
									
								
							@@ -1,21 +1,201 @@
 | 
			
		||||
MIT License
 | 
			
		||||
                                 Apache License
 | 
			
		||||
                           Version 2.0, January 2004
 | 
			
		||||
                        http://www.apache.org/licenses/
 | 
			
		||||
 | 
			
		||||
Copyright (c) 2023 RockYang
 | 
			
		||||
   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
 | 
			
		||||
 | 
			
		||||
Permission is hereby granted, free of charge, to any person obtaining a copy
 | 
			
		||||
of this software and associated documentation files (the "Software"), to deal
 | 
			
		||||
in the Software without restriction, including without limitation the rights
 | 
			
		||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 | 
			
		||||
copies of the Software, and to permit persons to whom the Software is
 | 
			
		||||
furnished to do so, subject to the following conditions:
 | 
			
		||||
   1. Definitions.
 | 
			
		||||
 | 
			
		||||
The above copyright notice and this permission notice shall be included in all
 | 
			
		||||
copies or substantial portions of the Software.
 | 
			
		||||
      "License" shall mean the terms and conditions for use, reproduction,
 | 
			
		||||
      and distribution as defined by Sections 1 through 9 of this document.
 | 
			
		||||
 | 
			
		||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 | 
			
		||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | 
			
		||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | 
			
		||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | 
			
		||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 | 
			
		||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 | 
			
		||||
SOFTWARE.
 | 
			
		||||
      "Licensor" shall mean the copyright owner or entity authorized by
 | 
			
		||||
      the copyright owner that is granting the License.
 | 
			
		||||
 | 
			
		||||
      "Legal Entity" shall mean the union of the acting entity and all
 | 
			
		||||
      other entities that control, are controlled by, or are under common
 | 
			
		||||
      control with that entity. For the purposes of this definition,
 | 
			
		||||
      "control" means (i) the power, direct or indirect, to cause the
 | 
			
		||||
      direction or management of such entity, whether by contract or
 | 
			
		||||
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
 | 
			
		||||
      outstanding shares, or (iii) beneficial ownership of such entity.
 | 
			
		||||
 | 
			
		||||
      "You" (or "Your") shall mean an individual or Legal Entity
 | 
			
		||||
      exercising permissions granted by this License.
 | 
			
		||||
 | 
			
		||||
      "Source" form shall mean the preferred form for making modifications,
 | 
			
		||||
      including but not limited to software source code, documentation
 | 
			
		||||
      source, and configuration files.
 | 
			
		||||
 | 
			
		||||
      "Object" form shall mean any form resulting from mechanical
 | 
			
		||||
      transformation or translation of a Source form, including but
 | 
			
		||||
      not limited to compiled object code, generated documentation,
 | 
			
		||||
      and conversions to other media types.
 | 
			
		||||
 | 
			
		||||
      "Work" shall mean the work of authorship, whether in Source or
 | 
			
		||||
      Object form, made available under the License, as indicated by a
 | 
			
		||||
      copyright notice that is included in or attached to the work
 | 
			
		||||
      (an example is provided in the Appendix below).
 | 
			
		||||
 | 
			
		||||
      "Derivative Works" shall mean any work, whether in Source or Object
 | 
			
		||||
      form, that is based on (or derived from) the Work and for which the
 | 
			
		||||
      editorial revisions, annotations, elaborations, or other modifications
 | 
			
		||||
      represent, as a whole, an original work of authorship. For the purposes
 | 
			
		||||
      of this License, Derivative Works shall not include works that remain
 | 
			
		||||
      separable from, or merely link (or bind by name) to the interfaces of,
 | 
			
		||||
      the Work and Derivative Works thereof.
 | 
			
		||||
 | 
			
		||||
      "Contribution" shall mean any work of authorship, including
 | 
			
		||||
      the original version of the Work and any modifications or additions
 | 
			
		||||
      to that Work or Derivative Works thereof, that is intentionally
 | 
			
		||||
      submitted to Licensor for inclusion in the Work by the copyright owner
 | 
			
		||||
      or by an individual or Legal Entity authorized to submit on behalf of
 | 
			
		||||
      the copyright owner. For the purposes of this definition, "submitted"
 | 
			
		||||
      means any form of electronic, verbal, or written communication sent
 | 
			
		||||
      to the Licensor or its representatives, including but not limited to
 | 
			
		||||
      communication on electronic mailing lists, source code control systems,
 | 
			
		||||
      and issue tracking systems that are managed by, or on behalf of, the
 | 
			
		||||
      Licensor for the purpose of discussing and improving the Work, but
 | 
			
		||||
      excluding communication that is conspicuously marked or otherwise
 | 
			
		||||
      designated in writing by the copyright owner as "Not a Contribution."
 | 
			
		||||
 | 
			
		||||
      "Contributor" shall mean Licensor and any individual or Legal Entity
 | 
			
		||||
      on behalf of whom a Contribution has been received by Licensor and
 | 
			
		||||
      subsequently incorporated within the Work.
 | 
			
		||||
 | 
			
		||||
   2. Grant of Copyright License. Subject to the terms and conditions of
 | 
			
		||||
      this License, each Contributor hereby grants to You a perpetual,
 | 
			
		||||
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 | 
			
		||||
      copyright license to reproduce, prepare Derivative Works of,
 | 
			
		||||
      publicly display, publicly perform, sublicense, and distribute the
 | 
			
		||||
      Work and such Derivative Works in Source or Object form.
 | 
			
		||||
 | 
			
		||||
   3. Grant of Patent License. Subject to the terms and conditions of
 | 
			
		||||
      this License, each Contributor hereby grants to You a perpetual,
 | 
			
		||||
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
 | 
			
		||||
      (except as stated in this section) patent license to make, have made,
 | 
			
		||||
      use, offer to sell, sell, import, and otherwise transfer the Work,
 | 
			
		||||
      where such license applies only to those patent claims licensable
 | 
			
		||||
      by such Contributor that are necessarily infringed by their
 | 
			
		||||
      Contribution(s) alone or by combination of their Contribution(s)
 | 
			
		||||
      with the Work to which such Contribution(s) was submitted. If You
 | 
			
		||||
      institute patent litigation against any entity (including a
 | 
			
		||||
      cross-claim or counterclaim in a lawsuit) alleging that the Work
 | 
			
		||||
      or a Contribution incorporated within the Work constitutes direct
 | 
			
		||||
      or contributory patent infringement, then any patent licenses
 | 
			
		||||
      granted to You under this License for that Work shall terminate
 | 
			
		||||
      as of the date such litigation is filed.
 | 
			
		||||
 | 
			
		||||
   4. Redistribution. You may reproduce and distribute copies of the
 | 
			
		||||
      Work or Derivative Works thereof in any medium, with or without
 | 
			
		||||
      modifications, and in Source or Object form, provided that You
 | 
			
		||||
      meet the following conditions:
 | 
			
		||||
 | 
			
		||||
      (a) You must give any other recipients of the Work or
 | 
			
		||||
          Derivative Works a copy of this License; and
 | 
			
		||||
 | 
			
		||||
      (b) You must cause any modified files to carry prominent notices
 | 
			
		||||
          stating that You changed the files; and
 | 
			
		||||
 | 
			
		||||
      (c) You must retain, in the Source form of any Derivative Works
 | 
			
		||||
          that You distribute, all copyright, patent, trademark, and
 | 
			
		||||
          attribution notices from the Source form of the Work,
 | 
			
		||||
          excluding those notices that do not pertain to any part of
 | 
			
		||||
          the Derivative Works; and
 | 
			
		||||
 | 
			
		||||
      (d) If the Work includes a "NOTICE" text file as part of its
 | 
			
		||||
          distribution, then any Derivative Works that You distribute must
 | 
			
		||||
          include a readable copy of the attribution notices contained
 | 
			
		||||
          within such NOTICE file, excluding those notices that do not
 | 
			
		||||
          pertain to any part of the Derivative Works, in at least one
 | 
			
		||||
          of the following places: within a NOTICE text file distributed
 | 
			
		||||
          as part of the Derivative Works; within the Source form or
 | 
			
		||||
          documentation, if provided along with the Derivative Works; or,
 | 
			
		||||
          within a display generated by the Derivative Works, if and
 | 
			
		||||
          wherever such third-party notices normally appear. The contents
 | 
			
		||||
          of the NOTICE file are for informational purposes only and
 | 
			
		||||
          do not modify the License. You may add Your own attribution
 | 
			
		||||
          notices within Derivative Works that You distribute, alongside
 | 
			
		||||
          or as an addendum to the NOTICE text from the Work, provided
 | 
			
		||||
          that such additional attribution notices cannot be construed
 | 
			
		||||
          as modifying the License.
 | 
			
		||||
 | 
			
		||||
      You may add Your own copyright statement to Your modifications and
 | 
			
		||||
      may provide additional or different license terms and conditions
 | 
			
		||||
      for use, reproduction, or distribution of Your modifications, or
 | 
			
		||||
      for any such Derivative Works as a whole, provided Your use,
 | 
			
		||||
      reproduction, and distribution of the Work otherwise complies with
 | 
			
		||||
      the conditions stated in this License.
 | 
			
		||||
 | 
			
		||||
   5. Submission of Contributions. Unless You explicitly state otherwise,
 | 
			
		||||
      any Contribution intentionally submitted for inclusion in the Work
 | 
			
		||||
      by You to the Licensor shall be under the terms and conditions of
 | 
			
		||||
      this License, without any additional terms or conditions.
 | 
			
		||||
      Notwithstanding the above, nothing herein shall supersede or modify
 | 
			
		||||
      the terms of any separate license agreement you may have executed
 | 
			
		||||
      with Licensor regarding such Contributions.
 | 
			
		||||
 | 
			
		||||
   6. Trademarks. This License does not grant permission to use the trade
 | 
			
		||||
      names, trademarks, service marks, or product names of the Licensor,
 | 
			
		||||
      except as required for reasonable and customary use in describing the
 | 
			
		||||
      origin of the Work and reproducing the content of the NOTICE file.
 | 
			
		||||
 | 
			
		||||
   7. Disclaimer of Warranty. Unless required by applicable law or
 | 
			
		||||
      agreed to in writing, Licensor provides the Work (and each
 | 
			
		||||
      Contributor provides its Contributions) on an "AS IS" BASIS,
 | 
			
		||||
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
 | 
			
		||||
      implied, including, without limitation, any warranties or conditions
 | 
			
		||||
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
 | 
			
		||||
      PARTICULAR PURPOSE. You are solely responsible for determining the
 | 
			
		||||
      appropriateness of using or redistributing the Work and assume any
 | 
			
		||||
      risks associated with Your exercise of permissions under this License.
 | 
			
		||||
 | 
			
		||||
   8. Limitation of Liability. In no event and under no legal theory,
 | 
			
		||||
      whether in tort (including negligence), contract, or otherwise,
 | 
			
		||||
      unless required by applicable law (such as deliberate and grossly
 | 
			
		||||
      negligent acts) or agreed to in writing, shall any Contributor be
 | 
			
		||||
      liable to You for damages, including any direct, indirect, special,
 | 
			
		||||
      incidental, or consequential damages of any character arising as a
 | 
			
		||||
      result of this License or out of the use or inability to use the
 | 
			
		||||
      Work (including but not limited to damages for loss of goodwill,
 | 
			
		||||
      work stoppage, computer failure or malfunction, or any and all
 | 
			
		||||
      other commercial damages or losses), even if such Contributor
 | 
			
		||||
      has been advised of the possibility of such damages.
 | 
			
		||||
 | 
			
		||||
   9. Accepting Warranty or Additional Liability. While redistributing
 | 
			
		||||
      the Work or Derivative Works thereof, You may choose to offer,
 | 
			
		||||
      and charge a fee for, acceptance of support, warranty, indemnity,
 | 
			
		||||
      or other liability obligations and/or rights consistent with this
 | 
			
		||||
      License. However, in accepting such obligations, You may act only
 | 
			
		||||
      on Your own behalf and on Your sole responsibility, not on behalf
 | 
			
		||||
      of any other Contributor, and only if You agree to indemnify,
 | 
			
		||||
      defend, and hold each Contributor harmless for any liability
 | 
			
		||||
      incurred by, or claims asserted against, such Contributor by reason
 | 
			
		||||
      of your accepting any such warranty or additional liability.
 | 
			
		||||
 | 
			
		||||
   END OF TERMS AND CONDITIONS
 | 
			
		||||
 | 
			
		||||
   APPENDIX: How to apply the Apache License to your work.
 | 
			
		||||
 | 
			
		||||
      To apply the Apache License to your work, attach the following
 | 
			
		||||
      boilerplate notice, with the fields enclosed by brackets "[]"
 | 
			
		||||
      replaced with your own identifying information. (Don't include
 | 
			
		||||
      the brackets!)  The text should be enclosed in the appropriate
 | 
			
		||||
      comment syntax for the file format. We also recommend that a
 | 
			
		||||
      file or class name and description of purpose be included on the
 | 
			
		||||
      same "printed page" as the copyright notice for easier
 | 
			
		||||
      identification within third-party archives.
 | 
			
		||||
 | 
			
		||||
   Copyright [yyyy] [name of copyright owner]
 | 
			
		||||
 | 
			
		||||
   Licensed under the Apache License, Version 2.0 (the "License");
 | 
			
		||||
   you may not use this file except in compliance with the License.
 | 
			
		||||
   You may obtain a copy of the License at
 | 
			
		||||
 | 
			
		||||
       http://www.apache.org/licenses/LICENSE-2.0
 | 
			
		||||
 | 
			
		||||
   Unless required by applicable law or agreed to in writing, software
 | 
			
		||||
   distributed under the License is distributed on an "AS IS" BASIS,
 | 
			
		||||
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
			
		||||
   See the License for the specific language governing permissions and
 | 
			
		||||
   limitations under the License.
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										401
									
								
								README.md
									
									
									
									
									
								
							
							
						
						
									
										401
									
								
								README.md
									
									
									
									
									
								
							@@ -1,16 +1,29 @@
 | 
			
		||||
# ChatGPT-Plus
 | 
			
		||||
# GeekAI
 | 
			
		||||
### 本项目已经正式更名为 GeekAI,请大家及时更新代码克隆地址。
 | 
			
		||||
 | 
			
		||||
**ChatGPT-PLUS** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
 | 
			
		||||
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。主要有如下特性:
 | 
			
		||||
**GeekAI** 基于 AI 大语言模型 API 实现的 AI 助手全套开源解决方案,自带运营管理后台,开箱即用。集成了 OpenAI, Azure,
 | 
			
		||||
ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了 MidJourney 和 Stable Diffusion AI绘画功能。
 | 
			
		||||
 | 
			
		||||
* 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
 | 
			
		||||
* 基于 Websocket 实现,完美的打字机体验。
 | 
			
		||||
* 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
 | 
			
		||||
* 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
 | 
			
		||||
* 支持 MidJourney / Stable Diffusion AI 绘画集成,开箱即用。
 | 
			
		||||
* 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
 | 
			
		||||
* 已集成支付宝支付功能,支持多种会员套餐和点卡购买功能。
 | 
			
		||||
* 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI 绘画函数插件。
 | 
			
		||||
主要特性:
 | 
			
		||||
 | 
			
		||||
- 完整的开源系统,前端应用和后台管理系统皆可开箱即用。
 | 
			
		||||
- 基于 Websocket 实现,完美的打字机体验。
 | 
			
		||||
- 内置了各种预训练好的角色应用,比如小红书写手,英语翻译大师,苏格拉底,孔子,乔布斯,周报助手等。轻松满足你的各种聊天和应用需求。
 | 
			
		||||
- 支持 OPenAI,Azure,文心一言,讯飞星火,清华 ChatGLM等多个大语言模型。
 | 
			
		||||
- 支持 Suno 文生音乐
 | 
			
		||||
- 支持 MidJourney / Stable Diffusion AI 绘画集成,文生图,图生图,换脸,融图。开箱即用。
 | 
			
		||||
- 支持使用个人微信二维码作为充值收费的支付渠道,无需企业支付通道。
 | 
			
		||||
- 已集成支付宝支付功能,微信支付,支持多种会员套餐和点卡购买功能。
 | 
			
		||||
- 集成插件 API 功能,可结合大语言模型的 function 功能开发各种强大的插件,已内置实现了微博热搜,今日头条,今日早报和 AI
 | 
			
		||||
  绘画函数插件。
 | 
			
		||||
 | 
			
		||||
### 🚀 更多功能请查看 [GeekAI-PLUS](https://github.com/yangjian102621/geekai-plus)
 | 
			
		||||
 | 
			
		||||
- [x] 更友好的 UI 界面
 | 
			
		||||
- [x] 支持 Dall-E 文生图功能
 | 
			
		||||
- [x] 支持文生思维导图
 | 
			
		||||
- [x] 支持为模型绑定指定的 API KEY,支持为角色绑定指定的模型等功能
 | 
			
		||||
- [x] 支持网站 Logo 版权等信息的修改
 | 
			
		||||
 | 
			
		||||
## 功能截图
 | 
			
		||||
 | 
			
		||||
@@ -23,17 +36,28 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
### MidJourney 专业绘画界面
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
### Stable-Diffusion 专业绘画页面
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
### 绘图作品展
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
### AI应用列表
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
### 会员充值
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
### 自动调用函数插件
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
@@ -51,365 +75,50 @@ ChatGLM,讯飞星火,文心一言等多个平台的大语言模型。集成了
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
### 7. 体验地址
 | 
			
		||||
### 体验地址
 | 
			
		||||
 | 
			
		||||
> 免费体验地址:[https://ai.r9it.com/chat](https://ai.r9it.com/chat) <br/>
 | 
			
		||||
> **注意:请合法使用,禁止输出任何敏感、不友好或违规的内容!!!**
 | 
			
		||||
 | 
			
		||||
## 快速部署
 | 
			
		||||
 | 
			
		||||
请参考文档 [**GeekAI 快速部署**](https://ai.r9it.com/docs/install/)。
 | 
			
		||||
 | 
			
		||||
## 使用须知
 | 
			
		||||
 | 
			
		||||
1. 本项目基于 MIT 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
 | 
			
		||||
1. 本项目基于 Apache2.0 协议,免费开放全部源代码,可以作为个人学习使用或者商用。
 | 
			
		||||
2. 如需商用必须保留版权信息,请自觉遵守。确保合法合规使用,在运营过程中产生的一切任何后果自负,与作者无关。
 | 
			
		||||
 | 
			
		||||
## 项目介绍
 | 
			
		||||
 | 
			
		||||
这一套完整的系统,包括前端聊天应用和一个后台管理系统。系统有用户鉴权,你可以自己使用,也可以部署直接给 C 端用户提供
 | 
			
		||||
ChatGPT 的服务。
 | 
			
		||||
 | 
			
		||||
### 项目的技术架构
 | 
			
		||||
 | 
			
		||||
新版的系统前后端都进行大改动的重构,后端还是用的 Gin Web 框架,但是作者整合了 fx 自动注入框架,整个后端应用结构非常简洁,特别适合二次开发。
 | 
			
		||||
另外,数据存储用 MySQL 替换了 leveldb, 因为要对 C 端,后期会涉及到很多业务数据查询统计,leveldb 已经完全不够用了。
 | 
			
		||||
 | 
			
		||||
> Gin + fx + MySQL
 | 
			
		||||
 | 
			
		||||
3.0 版本之后会陆续添加其他语言的 API 实现,比如 PHP,Java 等。考虑到作者精力有限,api 目录已经添加了,有兴趣的同学自主去认领各自擅长的语言去实现。
 | 
			
		||||
 | 
			
		||||
前端的框架还是:
 | 
			
		||||
 | 
			
		||||
> Vue3 + Element-Plus
 | 
			
		||||
 | 
			
		||||
前后台的页面风格已经全部变了,几乎所有页面样式代码都重写了。逻辑代码还是沿用之前的,毕竟功能没有太大的变化。
 | 
			
		||||
 | 
			
		||||
此次重构改版主要是为了后面功能的扩展准备了。
 | 
			
		||||
 | 
			
		||||
新版本已经实现的功能如下:
 | 
			
		||||
 | 
			
		||||
1. 引入用户体系,新增用户注册和登录功能。
 | 
			
		||||
2. 聊天页面改版,实现了跟 ChatGPT 官方版本一致的聊天体验。
 | 
			
		||||
3. 创建会话的时候可以选择聊天角色和模型。
 | 
			
		||||
4. 新增聊天设置功能,用户可以导入自己的 API KEY
 | 
			
		||||
5. 保存聊天记录,支持聊天上下文。
 | 
			
		||||
6. 重构后台管理模块,更友好,扩展性更好的后台管理系统。
 | 
			
		||||
7. 引入 ip2region 组件,记录用户的登录IP和地址。
 | 
			
		||||
8. 支持会话搜索过滤。
 | 
			
		||||
9. 支持微信支付充值
 | 
			
		||||
 | 
			
		||||
## 项目地址
 | 
			
		||||
 | 
			
		||||
* Github 地址:https://github.com/yangjian102621/chatgpt-plus
 | 
			
		||||
* 码云地址:https://gitee.com/blackfox/chatgpt-plus
 | 
			
		||||
* Github 地址:https://github.com/yangjian102621/geekai
 | 
			
		||||
* 码云地址:https://gitee.com/blackfox/geekai
 | 
			
		||||
 | 
			
		||||
## 客户端下载
 | 
			
		||||
 | 
			
		||||
目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/chatgpt-plus/releases/tag/v3.1.2
 | 
			
		||||
目前已经支持 Win/Linux/Mac/Android 客户端,下载地址为:https://github.com/yangjian102621/geekai/releases/tag/v3.1.2
 | 
			
		||||
 | 
			
		||||
## TODOLIST
 | 
			
		||||
 | 
			
		||||
* [x] 整合 Midjourney AI 绘画 API
 | 
			
		||||
* [x] 开发移动端聊天页面
 | 
			
		||||
* [x] 接入微信收款功能
 | 
			
		||||
* [x] 支持 ChatGPT 函数功能,通过函数实现插件
 | 
			
		||||
* [x] 开发桌面版应用
 | 
			
		||||
* [x] 开发手机 App 客户端
 | 
			
		||||
* [x] 支付宝支付功能
 | 
			
		||||
* [ ] 支持基于知识库的 AI 问答
 | 
			
		||||
* [ ] 会员推广功能
 | 
			
		||||
* [ ] 会员邀请注册推广功能
 | 
			
		||||
* [ ] 微信支付功能
 | 
			
		||||
 | 
			
		||||
## Docker 快速部署
 | 
			
		||||
## 项目文档
 | 
			
		||||
 | 
			
		||||
>
 | 
			
		||||
鉴于最新不少网友反馈在部署的时候遇到一些问题,大部分问题都是相同的,所以我这边做了一个视频教程 [五分钟部署自己的 ChatGPT 服务](https://www.bilibili.com/video/BV1H14y1B7Qw/)。
 | 
			
		||||
> 习惯看视频教程的朋友可以去看视频教程,视频的语速比较慢,建议 2 倍速观看。
 | 
			
		||||
最新的部署视频教程:[https://www.bilibili.com/video/BV1Cc411t7CX/](https://www.bilibili.com/video/BV1Cc411t7CX/)
 | 
			
		||||
 | 
			
		||||
V3.0.0 版本以后已经支持使用容器部署了,跳过所有的繁琐的环境准备,一条命令就可以轻松部署上线。
 | 
			
		||||
详细的部署和开发文档请参考 [**GeekAI 文档**](https://ai.r9it.com/docs/)。
 | 
			
		||||
 | 
			
		||||
### 1. 导入数据库
 | 
			
		||||
加微信进入微信讨论群可获取 **一键部署脚本(添加好友时请注明来自Github!!!)。**
 | 
			
		||||
 | 
			
		||||
首先我们需要创建一个 MySQL 容器,并导入初始数据库。
 | 
			
		||||
 | 
			
		||||
```shell
 | 
			
		||||
cd docker/mysql
 | 
			
		||||
# 创建 mysql 容器
 | 
			
		||||
docker-compose up -d
 | 
			
		||||
# 导入数据库
 | 
			
		||||
docker exec -i chatgpt-plus-mysql sh -c 'exec mysql -uroot -p12345678' < ../../database/chatgpt_plus-v3.1.8.sql
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
如果你本地已经安装了 MySQL 服务,那么你只需手动导入数据库即可。
 | 
			
		||||
 | 
			
		||||
```shell
 | 
			
		||||
# 连接数据库
 | 
			
		||||
mysql -u username -p password
 | 
			
		||||
# 导入数据库
 | 
			
		||||
source database/chatgpt_plus.sql
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 2. 修改配置文档
 | 
			
		||||
 | 
			
		||||
修改配置文档 `docker/conf/config.toml` 配置文档,修改代理地址和管理员密码:
 | 
			
		||||
 | 
			
		||||
```toml
 | 
			
		||||
Listen = "0.0.0.0:5678"
 | 
			
		||||
ProxyURL = "" # 如 http://127.0.0.1:7777
 | 
			
		||||
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8&parseTime=True&loc=Local"
 | 
			
		||||
StaticDir = "./static" # 静态资源的目录
 | 
			
		||||
StaticUrl = "/static" # 静态资源访问 URL
 | 
			
		||||
AesEncryptKey = ""
 | 
			
		||||
WeChatBot = false # 是否启动微信机器人
 | 
			
		||||
 | 
			
		||||
[Session]
 | 
			
		||||
  SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
 | 
			
		||||
  MaxAge = 86400
 | 
			
		||||
 | 
			
		||||
[Manager]
 | 
			
		||||
  Username = "admin"
 | 
			
		||||
  Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改
 | 
			
		||||
  
 | 
			
		||||
[Redis] # redis 配置信息
 | 
			
		||||
  Host = "localhost" 
 | 
			
		||||
  Port = 6379
 | 
			
		||||
  Password = ""
 | 
			
		||||
  DB = 0
 | 
			
		||||
  
 | 
			
		||||
[ApiConfig] # 微博热搜,今日头条等函数服务 API 配置,此为第三方插件服务,如需使用请联系作者开通
 | 
			
		||||
  ApiURL = ""
 | 
			
		||||
  AppId = ""
 | 
			
		||||
  Token = ""
 | 
			
		||||
 | 
			
		||||
[SmsConfig] # 阿里云短信服务配置
 | 
			
		||||
  AccessKey = ""
 | 
			
		||||
  AccessSecret = ""
 | 
			
		||||
  Product = "Dysmsapi"
 | 
			
		||||
  Domain = "dysmsapi.aliyuncs.com"
 | 
			
		||||
 | 
			
		||||
[ExtConfig] # MidJourney和微信机器人服务 API 配置,开通此功能需要配合 chatpgt-plus-exts 项目部署
 | 
			
		||||
  ApiURL = "" # 插件扩展 API 地址
 | 
			
		||||
  Token = "" # 这个 token 随便填,只要确保跟 chatgpt-plus-exts 项目的 token 一样就行 
 | 
			
		||||
  
 | 
			
		||||
[OSS] # OSS 配置,用于存储 MJ 绘画图片
 | 
			
		||||
   Active = "local" # 默认使用本地文件存储引擎
 | 
			
		||||
   [OSS.Local]
 | 
			
		||||
     BasePath = "./static/upload" # 本地文件上传根路径
 | 
			
		||||
     BaseURL = "http://localhost:5678/static/upload" # 本地上传文件根 URL 如果是线上,则直接设置为 /static/upload 即可
 | 
			
		||||
   [OSS.Minio]
 | 
			
		||||
     Endpoint = "" # 如 172.22.11.200:9000
 | 
			
		||||
     AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
 | 
			
		||||
     AccessSecret = ""
 | 
			
		||||
     Bucket = "chatgpt-plus" # 替换为你自己创建的 Bucket,注意要给 Bucket 设置公开的读权限,否则会出现图片无法显示。
 | 
			
		||||
     UseSSL = false
 | 
			
		||||
     Domain = "" # 地址必须是能够通过公网访问的,否则会出现图片无法显示。
 | 
			
		||||
   [OSS.QiNiu] # 七牛云 OSS 配置
 | 
			
		||||
       Zone = "z2" # 区域,z0:华东,z1: 华北,na0:北美,as0:新加坡
 | 
			
		||||
       AccessKey = ""
 | 
			
		||||
       AccessSecret = ""
 | 
			
		||||
       Bucket = ""
 | 
			
		||||
       Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
 | 
			
		||||
       
 | 
			
		||||
[MjConfig] # MidJourney AI 绘画配置
 | 
			
		||||
  Enabled = false # 是否启动 MidJourney 机器人服务
 | 
			
		||||
  UserToken = "" # 用户授权 Token
 | 
			
		||||
  BotToken = "" # Discord 机器人 Token
 | 
			
		||||
  GuildId = "" # 服务器 ID
 | 
			
		||||
  ChanelId = "" # 频道 ID
 | 
			
		||||
 | 
			
		||||
[SdConfig]
 | 
			
		||||
  Enabled = false # 是否启动 Stable Diffusion 机器人服务
 | 
			
		||||
  ApiURL = "http://172.22.11.200:7860" # stable-diffusion-webui API 地址
 | 
			
		||||
  ApiKey = "" # 如果开启了授权,这里需要配置授权的 ApiKey
 | 
			
		||||
  Txt2ImgJsonPath = "res/text2img.json" # 文生图的 API 请求报文 json 模板,允许自定义请求json报文,因为不同版本的 API 绘图的参数以及 fn_index 会不同。
 | 
			
		||||
  
 | 
			
		||||
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP,如果你没有启用支付服务,则该服务也无需启动
 | 
			
		||||
  Enabled = false # 是否启用 XXL JOB 服务
 | 
			
		||||
  ServerAddr = "http://172.22.11.47:8080/xxl-job-admin" # xxl-job-admin 管理地址
 | 
			
		||||
  ExecutorIp = "172.22.11.47" # 执行器 IP 地址
 | 
			
		||||
  ExecutorPort = "9999" # 执行器服务端口
 | 
			
		||||
  AccessToken = "xxl-job-api-token" # 执行器 API 通信 token
 | 
			
		||||
  RegistryKey = "chatgpt-plus" # 任务注册 key
 | 
			
		||||
 | 
			
		||||
[AlipayConfig]
 | 
			
		||||
  Enabled = false # 启用支付宝支付通道
 | 
			
		||||
  SandBox = false # 是否启用沙盒模式
 | 
			
		||||
  UserId = "2088721020750581" # 商户ID
 | 
			
		||||
  AppId = "9021000131658023" # App Id
 | 
			
		||||
  PrivateKey = "certs/alipay/privateKey.txt" # 应用私钥
 | 
			
		||||
  PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
 | 
			
		||||
  AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
 | 
			
		||||
  RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
 | 
			
		||||
  NotifyURL = "http://r9it.com:6004/api/payment/alipay/notify" # 支付异步回调地址
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
> 1. 如果你不知道如何获取 Discord 用户 Token 和 Bot Token
 | 
			
		||||
     请查参考 [Midjourney|如何集成到自己的平台](https://zhuanlan.zhihu.com/p/631079476)。
 | 
			
		||||
> 2. `Txt2ImgJsonPath`
 | 
			
		||||
     的默认用的是使用最广泛的 [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) 项目的
 | 
			
		||||
     API,如果你用的是其他版本,比如秋叶的懒人包部署的,那么请将对应的 text2img 的参数报文复制放在 `res/text2img.json`
 | 
			
		||||
     文件中即可。
 | 
			
		||||
 | 
			
		||||
修改 nginx 配置文档 `docker/conf/nginx/conf.d/chatgpt-plus.conf`,把后端转发的地址改成当前主机的内网 IP 地址。
 | 
			
		||||
 | 
			
		||||
```shell
 | 
			
		||||
 # 这里配置后端 API 的转发
 | 
			
		||||
location /api/ {
 | 
			
		||||
       proxy_http_version 1.1;
 | 
			
		||||
       proxy_connect_timeout 300s;
 | 
			
		||||
       proxy_read_timeout 300s;
 | 
			
		||||
       proxy_send_timeout 12s;
 | 
			
		||||
       proxy_set_header Host $host;
 | 
			
		||||
       proxy_set_header X-Real-IP $remote_addr;
 | 
			
		||||
       proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
 | 
			
		||||
       proxy_set_header Upgrade $http_upgrade;
 | 
			
		||||
       proxy_set_header Connection $connection_upgrade;
 | 
			
		||||
       proxy_pass http://172.28.173.76:6789; # 这里改成后端服务的内网 IP 地址
 | 
			
		||||
       
 | 
			
		||||
# 静态资源转发
 | 
			
		||||
location /static/ {
 | 
			
		||||
   proxy_pass http://172.22.11.47:5678; # 这里改成后端服务的内网 IP 地址
 | 
			
		||||
}
 | 
			
		||||
}
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 3. 启动应用
 | 
			
		||||
 | 
			
		||||
先修改 `docker/docker-compose.yaml` 文件中的镜像地址,改成最新的版本:
 | 
			
		||||
 | 
			
		||||
```yaml
 | 
			
		||||
version: '3'
 | 
			
		||||
services:
 | 
			
		||||
  # 后端 API 镜像
 | 
			
		||||
  chatgpt-plus-api:
 | 
			
		||||
    image: registry.cn-shenzhen.aliyuncs.com/geekmaster/chatgpt-plus-api:v3.1.8 #这里改成最新的 release 版本
 | 
			
		||||
    container_name: chatgpt-plus-api
 | 
			
		||||
    restart: always
 | 
			
		||||
    environment:
 | 
			
		||||
      - DEBUG=false
 | 
			
		||||
      - LOG_LEVEL=info
 | 
			
		||||
      - CONFIG_FILE=config.toml
 | 
			
		||||
    ports:
 | 
			
		||||
      - "5678:5678"
 | 
			
		||||
    volumes:
 | 
			
		||||
      - /usr/share/zoneinfo/Asia/Shanghai:/etc/localtime
 | 
			
		||||
      - ./conf/config.toml:/var/www/app/config.toml
 | 
			
		||||
      - ./logs:/var/www/app/logs
 | 
			
		||||
      - ./static:/var/www/app/static
 | 
			
		||||
 | 
			
		||||
  # 前端应用镜像
 | 
			
		||||
  chatgpt-plus-web:
 | 
			
		||||
    image: registry.cn-shenzhen.aliyuncs.com/geekmaster/chatgpt-plus-web:v3.1.8 #这里改成最新的 release 版本
 | 
			
		||||
    container_name: chatgpt-plus-web
 | 
			
		||||
    restart: always
 | 
			
		||||
    ports:
 | 
			
		||||
      - "8080:8080" # 这边是对外的端口,支持 8080,80和443
 | 
			
		||||
    volumes:
 | 
			
		||||
      - ./logs/nginx:/var/log/nginx
 | 
			
		||||
      - ./conf/nginx/conf.d:/etc/nginx/conf.d
 | 
			
		||||
      - ./conf/nginx/nginx.conf:/etc/nginx/nginx.conf
 | 
			
		||||
      - ./ssl:/etc/nginx/ssl
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
```shell
 | 
			
		||||
cd docker
 | 
			
		||||
docker-compose up -d
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
* 前端访问地址:http://localhost:8080/chat
 | 
			
		||||
* 后台管理地址:http://localhost:8080/admin
 | 
			
		||||
* 移动端地址:http://localhost:8080/mobile
 | 
			
		||||
 | 
			
		||||
> 注意:你得访问后台管理系统 http://localhost:8080/admin
 | 
			
		||||
> 输入你前面配置文档中设置的管理员用户名和密码登录。
 | 
			
		||||
> 然后进入 `API KEY 管理` 菜单,添加一个 OpenAI 的 API KEY 才可以正常开启 AI 对话。
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
最后进入前端聊天页面 [http://localhost:8080/chat](http://localhost:8080/chat)
 | 
			
		||||
你可以注册新用户,也可以使用系统默认有个账号:`18575670125/12345678` 登录聊天。
 | 
			
		||||
 | 
			
		||||
祝你使用愉快!!!
 | 
			
		||||
 | 
			
		||||
## 本地开发调试
 | 
			
		||||
 | 
			
		||||
本地开发同样要分别运行前端和后端程序。
 | 
			
		||||
 | 
			
		||||
### 运行后端程序
 | 
			
		||||
 | 
			
		||||
1. 同样你首先要 [导入数据库](#1-导入数据库)
 | 
			
		||||
2. 然后 [修改配置文档](#2-修改配置文档)
 | 
			
		||||
3. 运行后端程序:
 | 
			
		||||
 | 
			
		||||
    ```shell
 | 
			
		||||
    cd api 
 | 
			
		||||
    # 1. 先下载依赖
 | 
			
		||||
    go mod tidy
 | 
			
		||||
    # 2. 运行程序
 | 
			
		||||
    go run main.go
 | 
			
		||||
    # 如果你安装了 fresh 可以使用 fresh 实现热启动
 | 
			
		||||
    fresh -c fresh.conf
 | 
			
		||||
    ```
 | 
			
		||||
 | 
			
		||||
### 运行前端程序
 | 
			
		||||
 | 
			
		||||
同样先拷贝配置文档:
 | 
			
		||||
 | 
			
		||||
```shell
 | 
			
		||||
cd web
 | 
			
		||||
cp .env.production .env.development
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
编辑 `.env.development` 文件,修改后端 API 的访问路径:
 | 
			
		||||
 | 
			
		||||
```ini
 | 
			
		||||
VUE_APP_API_HOST=http://localhost:5678
 | 
			
		||||
VUE_APP_WS_HOST=ws://localhost:5678
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
配置好了之后就可以运行前端应用了:
 | 
			
		||||
 | 
			
		||||
```
 | 
			
		||||
# 安装依赖
 | 
			
		||||
npm install
 | 
			
		||||
# 运行
 | 
			
		||||
npm run dev
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
* 前端页面:http://localhost:8888/chat
 | 
			
		||||
* 后台管理页面:http://localhost:8888/admin
 | 
			
		||||
 | 
			
		||||
## 项目打包
 | 
			
		||||
 | 
			
		||||
由于本项目是采用异构开发的方式,所项目打包分成两步:首先编译后端程序,然后再打包前端应用。
 | 
			
		||||
 | 
			
		||||
### 打包前端
 | 
			
		||||
 | 
			
		||||
```shell
 | 
			
		||||
cd web
 | 
			
		||||
npm run build
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
### 打包后端
 | 
			
		||||
 | 
			
		||||
你可以根据个人需求将项目打包成 windows/linux/darwin 平台项目。
 | 
			
		||||
 | 
			
		||||
```shell
 | 
			
		||||
cd api
 | 
			
		||||
# for all platforms
 | 
			
		||||
make clean all
 | 
			
		||||
# for linux only
 | 
			
		||||
make clean linux
 | 
			
		||||
```
 | 
			
		||||
 | 
			
		||||
打包后的可执行文件在 `bin` 目录下。
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
## 参与贡献
 | 
			
		||||
 | 
			
		||||
个人的力量始终有限,任何形式的贡献都是欢迎的,包括但不限于贡献代码,优化文档,提交 issue 和 PR 等。
 | 
			
		||||
 | 
			
		||||
如果有兴趣的话,也可以加微信进入微信讨论群(**添加好友时请注明来自Github!!!**)。
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
#### 特此声明:不接受在微信或者微信群给开发者提 Bug,有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合!
 | 
			
		||||
#### 特此声明:由于个人时间有限,不接受在微信或者微信群给开发者提 Bug,有问题或者优化建议请提交 Issue 和 PR。非常感谢您的配合!
 | 
			
		||||
 | 
			
		||||
### Commit 类型
 | 
			
		||||
 | 
			
		||||
@@ -425,10 +134,6 @@ make clean linux
 | 
			
		||||
 | 
			
		||||
如果你觉得这个项目对你有帮助,并且情况允许的话,可以请作者喝杯咖啡,非常感谢你的支持~
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||

 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 | 
			
		||||

 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										1
									
								
								api/.gitignore
									
									
									
									
										vendored
									
									
								
							
							
						
						
									
										1
									
								
								api/.gitignore
									
									
									
									
										vendored
									
									
								
							@@ -18,4 +18,3 @@ data
 | 
			
		||||
config.toml
 | 
			
		||||
static/upload 
 | 
			
		||||
storage.json
 | 
			
		||||
certs/alipay/*
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										21
									
								
								api/Makefile
									
									
									
									
									
								
							
							
						
						
									
										21
									
								
								api/Makefile
									
									
									
									
									
								
							@@ -1,19 +1,14 @@
 | 
			
		||||
SHELL=/usr/bin/env bash
 | 
			
		||||
NAME := chatgpt-plus
 | 
			
		||||
all: window linux darwin
 | 
			
		||||
NAME := geekai
 | 
			
		||||
all: amd64 arm64
 | 
			
		||||
 | 
			
		||||
amd64:
 | 
			
		||||
	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-linux main.go
 | 
			
		||||
.PHONY: amd64
 | 
			
		||||
 | 
			
		||||
window:
 | 
			
		||||
	CGO_ENABLED=0 GOOS=windows GOARCH=amd64 go build -o bin/$(NAME)-amd64.exe main.go
 | 
			
		||||
.PHONY: window
 | 
			
		||||
 | 
			
		||||
linux:
 | 
			
		||||
	CGO_ENABLED=0 GOOS=linux GOARCH=amd64 go build -o bin/$(NAME)-amd64-linux main.go
 | 
			
		||||
.PHONY: linux
 | 
			
		||||
 | 
			
		||||
darwin:
 | 
			
		||||
	CGO_ENABLED=0 GOOS=darwin GOARCH=amd64 go build -o bin/$(NAME)-amd64-darwin main.go
 | 
			
		||||
.PHONY: darwin
 | 
			
		||||
arm64:
 | 
			
		||||
	CGO_ENABLED=0 GOOS=linux GOARCH=arm64 GOARM=7 go build -o bin/$(NAME)-linux main.go
 | 
			
		||||
.PHONY: arm64
 | 
			
		||||
 | 
			
		||||
clean:
 | 
			
		||||
	rm -rf bin/$(NAME)-*
 | 
			
		||||
 
 | 
			
		||||
@@ -1,6 +1,6 @@
 | 
			
		||||
Listen = "0.0.0.0:5678"
 | 
			
		||||
ProxyURL = "" # 如 http://127.0.0.1:7777
 | 
			
		||||
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8&parseTime=True&loc=Local"
 | 
			
		||||
MysqlDns = "root:12345678@tcp(172.22.11.200:3307)/chatgpt_plus?charset=utf8mb4&collation=utf8mb4_unicode_ci&parseTime=True&loc=Local"
 | 
			
		||||
StaticDir = "./static" # 静态资源的目录
 | 
			
		||||
StaticUrl = "/static" # 静态资源访问 URL
 | 
			
		||||
AesEncryptKey = ""
 | 
			
		||||
@@ -10,10 +10,6 @@ WeChatBot = false
 | 
			
		||||
  SecretKey = "azyehq3ivunjhbntz78isj00i4hz2mt9xtddysfucxakadq4qbfrt0b7q3lnvg80" # 注意:这个是 JWT Token 授权密钥,生产环境请务必更换
 | 
			
		||||
  MaxAge = 86400
 | 
			
		||||
 | 
			
		||||
[Manager]
 | 
			
		||||
  Username = "admin"
 | 
			
		||||
  Password = "admin123" # 如果是生产环境的话,这里管理员的密码记得修改
 | 
			
		||||
 | 
			
		||||
[Redis] # redis 配置信息
 | 
			
		||||
  Host = "localhost"
 | 
			
		||||
  Port = 6379
 | 
			
		||||
@@ -25,23 +21,28 @@ WeChatBot = false
 | 
			
		||||
  AppId = ""
 | 
			
		||||
  Token = ""
 | 
			
		||||
 | 
			
		||||
[SmsConfig] # 阿里云短信服务配置
 | 
			
		||||
  AccessKey = ""
 | 
			
		||||
  AccessSecret = ""
 | 
			
		||||
  Product = "Dysmsapi"
 | 
			
		||||
  Domain = "dysmsapi.aliyuncs.com"
 | 
			
		||||
  Sign = ""
 | 
			
		||||
  CodeTempId = ""
 | 
			
		||||
 | 
			
		||||
[ExtConfig] # MidJourney和微信机器人服务 API 配置,开通此功能需要配合 chatpgt-plus-exts 项目部署
 | 
			
		||||
  ApiURL = "" # 插件扩展 API 地址
 | 
			
		||||
  Token = "" # 这个 token 随便填,只要确保跟 chatgpt-plus-exts 项目的 token 一样就行
 | 
			
		||||
[SMS] # Sms 配置,用于发送短信
 | 
			
		||||
   Active = "Ali" # 当前启用的短信服务,默认使用阿里云
 | 
			
		||||
   [SMS.Bao]
 | 
			
		||||
      Username = ""
 | 
			
		||||
      Password = ""
 | 
			
		||||
      Domain = "api.smsbao.com"
 | 
			
		||||
      Sign = "【极客学长】"
 | 
			
		||||
      CodeTemplate = "您的验证码是{code}。5分钟有效,若非本人操作,请忽略本短信。"
 | 
			
		||||
   [SMS.Ali]
 | 
			
		||||
      AccessKey = ""
 | 
			
		||||
      AccessSecret = ""
 | 
			
		||||
      Product = "Dysmsapi"
 | 
			
		||||
      Domain = "dysmsapi.aliyuncs.com"
 | 
			
		||||
      Sign = ""
 | 
			
		||||
      CodeTempId = ""
 | 
			
		||||
 | 
			
		||||
[OSS] # OSS 配置,用于存储 MJ 绘画图片
 | 
			
		||||
   Active = "local" # 默认使用本地文件存储引擎
 | 
			
		||||
   [OSS.Local]
 | 
			
		||||
     BasePath = "./static/upload" # 本地文件上传根路径
 | 
			
		||||
     BaseURL = "http://localhost:5678/static/upload" # 本地上传文件根 URL 如果是线上,则直接设置为 /static/upload 即可
 | 
			
		||||
     BaseURL = "http://localhost:5678/static/upload" # 本地上传文件前缀 URL,线上需要把 localhost 替换成自己的实际域名或者IP
 | 
			
		||||
   [OSS.Minio]
 | 
			
		||||
     Endpoint = "" # 如 172.22.11.200:9000
 | 
			
		||||
     AccessKey = "" # 自己去 Minio 控制台去创建一个 Access Key
 | 
			
		||||
@@ -55,19 +56,30 @@ WeChatBot = false
 | 
			
		||||
       AccessSecret = ""
 | 
			
		||||
       Bucket = ""
 | 
			
		||||
       Domain = "" # OSS Bucket 所绑定的域名,如 https://img.r9it.com
 | 
			
		||||
   [OSS.AliYun]
 | 
			
		||||
       Endpoint = "oss-cn-hangzhou.aliyuncs.com"
 | 
			
		||||
       AccessKey = ""
 | 
			
		||||
       AccessSecret = ""
 | 
			
		||||
       Bucket = "chatgpt-plus"
 | 
			
		||||
       SubDir = ""
 | 
			
		||||
       Domain = ""
 | 
			
		||||
 | 
			
		||||
[MjConfig]
 | 
			
		||||
  Enabled = false
 | 
			
		||||
  UserToken = ""
 | 
			
		||||
  BotToken = ""
 | 
			
		||||
  GuildId = ""
 | 
			
		||||
  ChanelId = ""
 | 
			
		||||
[[MjProxyConfigs]]
 | 
			
		||||
  Enabled = true
 | 
			
		||||
  ApiURL = "http://midjourney-proxy:8082"
 | 
			
		||||
  ApiKey = "sk-geekmaster"
 | 
			
		||||
 | 
			
		||||
[SdConfig]
 | 
			
		||||
[[MjPlusConfigs]]
 | 
			
		||||
  Enabled = false
 | 
			
		||||
  ApiURL = "http://172.22.11.200:7860"
 | 
			
		||||
  ApiURL = "https://api.chat-plus.net"
 | 
			
		||||
  Mode = "fast" # MJ 绘画模式,可选值 relax/fast/turbo
 | 
			
		||||
  ApiKey = "sk-xxx"
 | 
			
		||||
 | 
			
		||||
[[SdConfigs]]
 | 
			
		||||
  Enabled = false
 | 
			
		||||
  ApiURL = ""
 | 
			
		||||
  ApiKey = ""
 | 
			
		||||
  Txt2ImgJsonPath = "res/text2img.json"
 | 
			
		||||
  Txt2ImgJsonPath = "res/sd/text2img.json"
 | 
			
		||||
 | 
			
		||||
[XXLConfig] # xxl-job 配置,需要你部署 XXL-JOB 定时任务工具,用来定期清理未支付订单和清理过期 VIP,如果你没有启用支付服务,则该服务也无需启动
 | 
			
		||||
  Enabled = false # 是否启用 XXL JOB 服务
 | 
			
		||||
@@ -86,4 +98,28 @@ WeChatBot = false
 | 
			
		||||
  PublicKey = "certs/alipay/appPublicCert.crt" # 应用公钥证书
 | 
			
		||||
  AlipayPublicKey = "certs/alipay/alipayPublicCert.crt" # 支付宝公钥证书
 | 
			
		||||
  RootCert = "certs/alipay/alipayRootCert.crt" # 支付宝根证书
 | 
			
		||||
  NotifyURL = "http://r9it.com:6004/api/payment/alipay/notify" # 支付异步回调地址
 | 
			
		||||
  NotifyURL = "https://ai.r9it.com/api/payment/alipay/notify" # 支付异步回调地址
 | 
			
		||||
 | 
			
		||||
[HuPiPayConfig]
 | 
			
		||||
  Enabled = false
 | 
			
		||||
  Name = "wechat"
 | 
			
		||||
  AppId = ""
 | 
			
		||||
  AppSecret = ""
 | 
			
		||||
  ApiURL = "https://api.xunhupay.com"
 | 
			
		||||
  NotifyURL = "https://ai.r9it.com/api/payment/hupipay/notify"
 | 
			
		||||
 | 
			
		||||
[SmtpConfig] # 注意,阿里云服务器禁用了25号端口,请使用 465 端口,并开启 TLS 连接
 | 
			
		||||
  UseTls = false
 | 
			
		||||
  Host = "smtp.163.com"
 | 
			
		||||
  Port = 25
 | 
			
		||||
  AppName = "极客学长"
 | 
			
		||||
  From = "test@163.com" # 发件邮箱人地址
 | 
			
		||||
  Password = "" #邮箱 stmp 服务授权码
 | 
			
		||||
 | 
			
		||||
[JPayConfig] # PayJs 支付配置
 | 
			
		||||
  Enabled = false
 | 
			
		||||
  Name = "wechat" # 请不要改动
 | 
			
		||||
  AppId = "" # 商户 ID
 | 
			
		||||
  PrivateKey = "" # 秘钥
 | 
			
		||||
  ApiURL = "https://payjs.cn"
 | 
			
		||||
  NotifyURL = "https://ai.r9it.com/api/payment/payjs/notify" # 异步回调地址,域名改成你自己的
 | 
			
		||||
@@ -1,19 +1,31 @@
 | 
			
		||||
package core
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/fun"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"github.com/nfnt/resize"
 | 
			
		||||
	"golang.org/x/image/webp"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"image"
 | 
			
		||||
	"image/jpeg"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"os"
 | 
			
		||||
	"runtime/debug"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -23,31 +35,28 @@ type AppServer struct {
 | 
			
		||||
	Debug        bool
 | 
			
		||||
	Config       *types.AppConfig
 | 
			
		||||
	Engine       *gin.Engine
 | 
			
		||||
	ChatContexts *types.LMap[string, []interface{}] // 聊天上下文 Map [chatId] => []Message
 | 
			
		||||
	ChatContexts *types.LMap[string, []types.Message] // 聊天上下文 Map [chatId] => []Message
 | 
			
		||||
 | 
			
		||||
	ChatConfig *types.ChatConfig   // chat config cache
 | 
			
		||||
	SysConfig  *types.SystemConfig // system config cache
 | 
			
		||||
	SysConfig *types.SystemConfig // system config cache
 | 
			
		||||
 | 
			
		||||
	// 保存 Websocket 会话 UserId, 每个 UserId 只能连接一次
 | 
			
		||||
	// 防止第三方直接连接 socket 调用 OpenAI API
 | 
			
		||||
	ChatSession   *types.LMap[string, *types.ChatSession] //map[sessionId]UserId
 | 
			
		||||
	ChatClients   *types.LMap[string, *types.WsClient]    // map[sessionId]Websocket 连接集合
 | 
			
		||||
	ReqCancelFunc *types.LMap[string, context.CancelFunc] // HttpClient 请求取消 handle function
 | 
			
		||||
	Functions     map[string]fun.Function
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewServer(appConfig *types.AppConfig, functions map[string]fun.Function) *AppServer {
 | 
			
		||||
func NewServer(appConfig *types.AppConfig) *AppServer {
 | 
			
		||||
	gin.SetMode(gin.ReleaseMode)
 | 
			
		||||
	gin.DefaultWriter = io.Discard
 | 
			
		||||
	return &AppServer{
 | 
			
		||||
		Debug:         false,
 | 
			
		||||
		Config:        appConfig,
 | 
			
		||||
		Engine:        gin.Default(),
 | 
			
		||||
		ChatContexts:  types.NewLMap[string, []interface{}](),
 | 
			
		||||
		ChatContexts:  types.NewLMap[string, []types.Message](),
 | 
			
		||||
		ChatSession:   types.NewLMap[string, *types.ChatSession](),
 | 
			
		||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		ReqCancelFunc: types.NewLMap[string, context.CancelFunc](),
 | 
			
		||||
		Functions:     functions,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -57,30 +66,22 @@ func (s *AppServer) Init(debug bool, client *redis.Client) {
 | 
			
		||||
		logger.Info("Enabled debug mode")
 | 
			
		||||
	}
 | 
			
		||||
	s.Engine.Use(corsMiddleware())
 | 
			
		||||
	s.Engine.Use(staticResourceMiddleware())
 | 
			
		||||
	s.Engine.Use(authorizeMiddleware(s, client))
 | 
			
		||||
	s.Engine.Use(parameterHandlerMiddleware())
 | 
			
		||||
	s.Engine.Use(errorHandler)
 | 
			
		||||
	// 添加静态资源访问
 | 
			
		||||
	s.Engine.Static("/static", s.Config.StaticDir)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *AppServer) Run(db *gorm.DB) error {
 | 
			
		||||
	// load chat config from database
 | 
			
		||||
	var chatConfig model.Config
 | 
			
		||||
	res := db.Where("marker", "chat").First(&chatConfig)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return res.Error
 | 
			
		||||
	}
 | 
			
		||||
	err := utils.JsonDecode(chatConfig.Config, &s.ChatConfig)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	// load system configs
 | 
			
		||||
	var sysConfig model.Config
 | 
			
		||||
	res = db.Where("marker", "system").First(&sysConfig)
 | 
			
		||||
	res := db.Where("marker", "system").First(&sysConfig)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return res.Error
 | 
			
		||||
	}
 | 
			
		||||
	err = utils.JsonDecode(sysConfig.Config, &s.SysConfig)
 | 
			
		||||
	err := utils.JsonDecode(sysConfig.Config, &s.SysConfig)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
@@ -138,72 +139,64 @@ func corsMiddleware() gin.HandlerFunc {
 | 
			
		||||
// 用户授权验证
 | 
			
		||||
func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		if c.Request.URL.Path == "/api/user/login" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/admin/login" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/user/register" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/reward/notify" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/mj/notify" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/chat/history" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/chat/detail" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/role/list" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/mj/jobs" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/mj/proxy" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/sd/jobs" ||
 | 
			
		||||
			strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
 | 
			
		||||
			strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
 | 
			
		||||
			strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
 | 
			
		||||
			strings.HasPrefix(c.Request.URL.Path, "/static/") ||
 | 
			
		||||
			c.Request.URL.Path == "/api/admin/config/get" {
 | 
			
		||||
			c.Next()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var tokenString string
 | 
			
		||||
		if strings.Contains(c.Request.URL.Path, "/api/admin/") { // 后台管理 API
 | 
			
		||||
		isAdminApi := strings.Contains(c.Request.URL.Path, "/api/admin/")
 | 
			
		||||
		if isAdminApi { // 后台管理 API
 | 
			
		||||
			tokenString = c.GetHeader(types.AdminAuthHeader)
 | 
			
		||||
		} else if c.Request.URL.Path == "/api/chat/new" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/mj/client" ||
 | 
			
		||||
			c.Request.URL.Path == "/api/sd/client" {
 | 
			
		||||
		} else if c.Request.URL.Path == "/api/chat/new" {
 | 
			
		||||
			tokenString = c.Query("token")
 | 
			
		||||
		} else {
 | 
			
		||||
			tokenString = c.GetHeader(types.UserAuthHeader)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if tokenString == "" {
 | 
			
		||||
			resp.ERROR(c, "You should put Authorization in request headers")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
			if needLogin(c) {
 | 
			
		||||
				resp.ERROR(c, "You should put Authorization in request headers")
 | 
			
		||||
				c.Abort()
 | 
			
		||||
				return
 | 
			
		||||
			} else { // 直接放行
 | 
			
		||||
				c.Next()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
 | 
			
		||||
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
 | 
			
		||||
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok && needLogin(c) {
 | 
			
		||||
				return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
 | 
			
		||||
			}
 | 
			
		||||
			if isAdminApi {
 | 
			
		||||
				return []byte(s.Config.AdminSession.SecretKey), nil
 | 
			
		||||
			} else {
 | 
			
		||||
				return []byte(s.Config.Session.SecretKey), nil
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			return []byte(s.Config.Session.SecretKey), nil
 | 
			
		||||
		})
 | 
			
		||||
 | 
			
		||||
		if err != nil {
 | 
			
		||||
		if err != nil && needLogin(c) {
 | 
			
		||||
			resp.NotAuth(c, fmt.Sprintf("Error with parse auth token: %v", err))
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		claims, ok := token.Claims.(jwt.MapClaims)
 | 
			
		||||
		if !ok || !token.Valid {
 | 
			
		||||
		if !ok || !token.Valid && needLogin(c) {
 | 
			
		||||
			resp.NotAuth(c, "Token is invalid")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
 | 
			
		||||
		if expr > 0 && int64(expr) < time.Now().Unix() {
 | 
			
		||||
		if expr > 0 && int64(expr) < time.Now().Unix() && needLogin(c) {
 | 
			
		||||
			resp.NotAuth(c, "Token is expired")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		key := fmt.Sprintf("users/%v", claims["user_id"])
 | 
			
		||||
		if _, err := client.Get(context.Background(), key).Result(); err != nil {
 | 
			
		||||
		if isAdminApi {
 | 
			
		||||
			key = fmt.Sprintf("admin/%v", claims["user_id"])
 | 
			
		||||
		}
 | 
			
		||||
		if _, err := client.Get(context.Background(), key).Result(); err != nil && needLogin(c) {
 | 
			
		||||
			resp.NotAuth(c, "Token is not found in redis")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
			return
 | 
			
		||||
@@ -211,3 +204,173 @@ func authorizeMiddleware(s *AppServer, client *redis.Client) gin.HandlerFunc {
 | 
			
		||||
		c.Set(types.LoginUserID, claims["user_id"])
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func needLogin(c *gin.Context) bool {
 | 
			
		||||
	if c.Request.URL.Path == "/api/user/login" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/user/logout" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/user/resetPass" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/admin/login" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/admin/logout" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/admin/login/captcha" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/user/register" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/user/session" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/chat/history" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/chat/detail" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/chat/list" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/role/list" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/model/list" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/mj/imgWall" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/mj/client" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/mj/notify" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/invite/hits" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/sd/imgWall" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/sd/client" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/dall/imgWall" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/dall/client" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/config/get" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/product/list" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/menu/list" ||
 | 
			
		||||
		c.Request.URL.Path == "/api/markMap/client" ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/test") ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/function/") ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/sms/") ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/captcha/") ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/api/payment/") ||
 | 
			
		||||
		strings.HasPrefix(c.Request.URL.Path, "/static/") {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 统一参数处理
 | 
			
		||||
func parameterHandlerMiddleware() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
		// GET 参数处理
 | 
			
		||||
		params := c.Request.URL.Query()
 | 
			
		||||
		for key, values := range params {
 | 
			
		||||
			for i, value := range values {
 | 
			
		||||
				params[key][i] = strings.TrimSpace(value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// update get parameters
 | 
			
		||||
		c.Request.URL.RawQuery = params.Encode()
 | 
			
		||||
		// skip file upload requests
 | 
			
		||||
		contentType := c.Request.Header.Get("Content-Type")
 | 
			
		||||
		if strings.Contains(contentType, "multipart/form-data") {
 | 
			
		||||
			c.Next()
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if strings.Contains(contentType, "application/json") {
 | 
			
		||||
			// process POST JSON request body
 | 
			
		||||
			bodyBytes, err := io.ReadAll(c.Request.Body)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.Next()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 还原请求体
 | 
			
		||||
			c.Request.Body = io.NopCloser(bytes.NewBuffer(bodyBytes))
 | 
			
		||||
			// 将请求体解析为 JSON
 | 
			
		||||
			var jsonData map[string]interface{}
 | 
			
		||||
			if err := c.ShouldBindJSON(&jsonData); err != nil {
 | 
			
		||||
				c.Next()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 对 JSON 数据中的字符串值去除两端空格
 | 
			
		||||
			trimJSONStrings(jsonData)
 | 
			
		||||
			// 更新请求体
 | 
			
		||||
			c.Request.Body = io.NopCloser(bytes.NewBufferString(utils.JsonEncode(jsonData)))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 递归对 JSON 数据中的字符串值去除两端空格
 | 
			
		||||
func trimJSONStrings(data interface{}) {
 | 
			
		||||
	switch v := data.(type) {
 | 
			
		||||
	case map[string]interface{}:
 | 
			
		||||
		for key, value := range v {
 | 
			
		||||
			switch valueType := value.(type) {
 | 
			
		||||
			case string:
 | 
			
		||||
				v[key] = strings.TrimSpace(valueType)
 | 
			
		||||
			case map[string]interface{}, []interface{}:
 | 
			
		||||
				trimJSONStrings(value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	case []interface{}:
 | 
			
		||||
		for i, value := range v {
 | 
			
		||||
			switch valueType := value.(type) {
 | 
			
		||||
			case string:
 | 
			
		||||
				v[i] = strings.TrimSpace(valueType)
 | 
			
		||||
			case map[string]interface{}, []interface{}:
 | 
			
		||||
				trimJSONStrings(value)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 静态资源中间件
 | 
			
		||||
func staticResourceMiddleware() gin.HandlerFunc {
 | 
			
		||||
	return func(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
		url := c.Request.URL.String()
 | 
			
		||||
		// 拦截生成缩略图请求
 | 
			
		||||
		if strings.HasPrefix(url, "/static/") && strings.Contains(url, "?imageView2") {
 | 
			
		||||
			r := strings.SplitAfter(url, "imageView2")
 | 
			
		||||
			size := strings.Split(r[1], "/")
 | 
			
		||||
			if len(size) != 8 {
 | 
			
		||||
				c.String(http.StatusNotFound, "invalid thumb args")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			with := utils.IntValue(size[3], 0)
 | 
			
		||||
			height := utils.IntValue(size[5], 0)
 | 
			
		||||
			quality := utils.IntValue(size[7], 75)
 | 
			
		||||
 | 
			
		||||
			// 打开图片文件
 | 
			
		||||
			filePath := strings.TrimLeft(c.Request.URL.Path, "/")
 | 
			
		||||
			file, err := os.Open(filePath)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.String(http.StatusNotFound, "Image not found")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			defer file.Close()
 | 
			
		||||
 | 
			
		||||
			// 解码图片
 | 
			
		||||
			img, _, err := image.Decode(file)
 | 
			
		||||
			// for .webp image
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				img, err = webp.Decode(file)
 | 
			
		||||
			}
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				c.String(http.StatusInternalServerError, "Error decoding image")
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var newImg image.Image
 | 
			
		||||
			if height == 0 || with == 0 {
 | 
			
		||||
				// 固定宽度,高度自适应
 | 
			
		||||
				newImg = resize.Resize(uint(with), uint(height), img, resize.Lanczos3)
 | 
			
		||||
			} else {
 | 
			
		||||
				// 生成缩略图
 | 
			
		||||
				newImg = resize.Thumbnail(uint(with), uint(height), img, resize.Lanczos3)
 | 
			
		||||
			}
 | 
			
		||||
			var buffer bytes.Buffer
 | 
			
		||||
			err = jpeg.Encode(&buffer, newImg, &jpeg.Options{Quality: quality})
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				c.String(http.StatusInternalServerError, err.Error())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 设置图片缓存有效期为一年 (365天)
 | 
			
		||||
			c.Header("Cache-Control", "max-age=31536000, public")
 | 
			
		||||
			// 直接输出图像数据流
 | 
			
		||||
			c.Data(http.StatusOK, "image/jpeg", buffer.Bytes())
 | 
			
		||||
			c.Abort() // 中断请求
 | 
			
		||||
		}
 | 
			
		||||
		c.Next()
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,10 +1,17 @@
 | 
			
		||||
package core
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"os"
 | 
			
		||||
 | 
			
		||||
	"github.com/BurntSushi/toml"
 | 
			
		||||
@@ -14,18 +21,16 @@ var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
func NewDefaultConfig() *types.AppConfig {
 | 
			
		||||
	return &types.AppConfig{
 | 
			
		||||
		Listen:        "0.0.0.0:5678",
 | 
			
		||||
		ProxyURL:      "",
 | 
			
		||||
		Manager:       types.Manager{Username: "admin", Password: "admin123"},
 | 
			
		||||
		StaticDir:     "./static",
 | 
			
		||||
		StaticUrl:     "http://localhost/5678/static",
 | 
			
		||||
		Redis:         types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
 | 
			
		||||
		AesEncryptKey: utils.RandString(24),
 | 
			
		||||
		Listen:    "0.0.0.0:5678",
 | 
			
		||||
		ProxyURL:  "",
 | 
			
		||||
		StaticDir: "./static",
 | 
			
		||||
		StaticUrl: "http://localhost/5678/static",
 | 
			
		||||
		Redis:     types.RedisConfig{Host: "localhost", Port: 6379, Password: ""},
 | 
			
		||||
		Session: types.Session{
 | 
			
		||||
			SecretKey: utils.RandString(64),
 | 
			
		||||
			MaxAge:    86400,
 | 
			
		||||
		},
 | 
			
		||||
		ApiConfig: types.ChatPlusApiConfig{},
 | 
			
		||||
		ApiConfig: types.ApiConfig{},
 | 
			
		||||
		OSS: types.OSSConfig{
 | 
			
		||||
			Active: "local",
 | 
			
		||||
			Local: types.LocalStorageConfig{
 | 
			
		||||
@@ -33,8 +38,6 @@ func NewDefaultConfig() *types.AppConfig {
 | 
			
		||||
				BasePath: "./static/upload",
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
		MjConfig:     types.MidJourneyConfig{Enabled: false},
 | 
			
		||||
		SdConfig:     types.StableDiffusionConfig{Enabled: false, Txt2ImgJsonPath: "res/text2img.json"},
 | 
			
		||||
		WeChatBot:    false,
 | 
			
		||||
		AlipayConfig: types.AlipayConfig{Enabled: false, SandBox: false},
 | 
			
		||||
	}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,12 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
// ApiRequest API 请求实体
 | 
			
		||||
type ApiRequest struct {
 | 
			
		||||
	Model       string        `json:"model,omitempty"` // 兼容百度文心一言
 | 
			
		||||
@@ -8,7 +15,13 @@ type ApiRequest struct {
 | 
			
		||||
	Stream      bool          `json:"stream"`
 | 
			
		||||
	Messages    []interface{} `json:"messages,omitempty"`
 | 
			
		||||
	Prompt      []interface{} `json:"prompt,omitempty"` // 兼容 ChatGLM
 | 
			
		||||
	Functions   []Function    `json:"functions,omitempty"`
 | 
			
		||||
	Tools       []Tool        `json:"tools,omitempty"`
 | 
			
		||||
	Functions   []interface{} `json:"functions,omitempty"` // 兼容中转平台
 | 
			
		||||
 | 
			
		||||
	ToolChoice string `json:"tool_choice,omitempty"`
 | 
			
		||||
 | 
			
		||||
	Input      map[string]interface{} `json:"input,omitempty"`      //兼容阿里通义千问
 | 
			
		||||
	Parameters map[string]interface{} `json:"parameters,omitempty"` //兼容阿里通义千问
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Message struct {
 | 
			
		||||
@@ -27,10 +40,14 @@ type ChoiceItem struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Delta struct {
 | 
			
		||||
	Role         string       `json:"role"`
 | 
			
		||||
	Name         string       `json:"name"`
 | 
			
		||||
	Content      interface{}  `json:"content"`
 | 
			
		||||
	FunctionCall FunctionCall `json:"function_call,omitempty"`
 | 
			
		||||
	Role         string      `json:"role"`
 | 
			
		||||
	Name         string      `json:"name"`
 | 
			
		||||
	Content      interface{} `json:"content"`
 | 
			
		||||
	ToolCalls    []ToolCall  `json:"tool_calls,omitempty"`
 | 
			
		||||
	FunctionCall struct {
 | 
			
		||||
		Name      string `json:"name,omitempty"`
 | 
			
		||||
		Arguments string `json:"arguments,omitempty"`
 | 
			
		||||
	} `json:"function_call,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ChatSession 聊天会话对象
 | 
			
		||||
@@ -44,10 +61,15 @@ type ChatSession struct {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatModel struct {
 | 
			
		||||
	Id       uint     `json:"id"`
 | 
			
		||||
	Platform Platform `json:"platform"`
 | 
			
		||||
	Value    string   `json:"value"`
 | 
			
		||||
	Weight   int      `json:"weight"`
 | 
			
		||||
	Id          uint     `json:"id"`
 | 
			
		||||
	Platform    Platform `json:"platform"`
 | 
			
		||||
	Name        string   `json:"name"`
 | 
			
		||||
	Value       string   `json:"value"`
 | 
			
		||||
	Power       int      `json:"power"`
 | 
			
		||||
	MaxTokens   int      `json:"max_tokens"`  // 最大响应长度
 | 
			
		||||
	MaxContext  int      `json:"max_context"` // 最大上下文长度
 | 
			
		||||
	Temperature float32  `json:"temperature"` // 模型温度
 | 
			
		||||
	KeyId       int      `json:"key_id"`      // 绑定 API KEY
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ApiError struct {
 | 
			
		||||
@@ -61,17 +83,37 @@ type ApiError struct {
 | 
			
		||||
 | 
			
		||||
const PromptMsg = "prompt" // prompt message
 | 
			
		||||
const ReplyMsg = "reply"   // reply message
 | 
			
		||||
const MjMsg = "mj"
 | 
			
		||||
 | 
			
		||||
var ModelToTokens = map[string]int{
 | 
			
		||||
	"gpt-3.5-turbo":     4096,
 | 
			
		||||
	"gpt-3.5-turbo-16k": 16384,
 | 
			
		||||
	"gpt-4":             8192,
 | 
			
		||||
	"gpt-4-32k":         32768,
 | 
			
		||||
	"chatglm_pro":       32768, // 清华智普
 | 
			
		||||
	"chatglm_std":       16384,
 | 
			
		||||
	"chatglm_lite":      4096,
 | 
			
		||||
	"ernie_bot_turbo":   8192, // 文心一言
 | 
			
		||||
	"general":           8192, // 科大讯飞
 | 
			
		||||
	"general2":          8192,
 | 
			
		||||
// PowerType 算力日志类型
 | 
			
		||||
type PowerType int
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	PowerRecharge = PowerType(1) // 充值
 | 
			
		||||
	PowerConsume  = PowerType(2) // 消费
 | 
			
		||||
	PowerRefund   = PowerType(3) // 任务(SD,MJ)执行失败,退款
 | 
			
		||||
	PowerInvite   = PowerType(4) // 邀请奖励
 | 
			
		||||
	PowerReward   = PowerType(5) // 众筹
 | 
			
		||||
	PowerGift     = PowerType(6) // 系统赠送
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
func (t PowerType) String() string {
 | 
			
		||||
	switch t {
 | 
			
		||||
	case PowerRecharge:
 | 
			
		||||
		return "充值"
 | 
			
		||||
	case PowerConsume:
 | 
			
		||||
		return "消费"
 | 
			
		||||
	case PowerRefund:
 | 
			
		||||
		return "退款"
 | 
			
		||||
	case PowerReward:
 | 
			
		||||
		return "众筹"
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	return "其他"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type PowerMark int
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	PowerSub = PowerMark(0)
 | 
			
		||||
	PowerAdd = PowerMark(1)
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,12 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
 
 | 
			
		||||
@@ -1,67 +1,79 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type AppConfig struct {
 | 
			
		||||
	Path          string `toml:"-"`
 | 
			
		||||
	Listen        string
 | 
			
		||||
	Session       Session
 | 
			
		||||
	ProxyURL      string
 | 
			
		||||
	MysqlDns      string            // mysql 连接地址
 | 
			
		||||
	Manager       Manager           // 后台管理员账户信息
 | 
			
		||||
	StaticDir     string            // 静态资源目录
 | 
			
		||||
	StaticUrl     string            // 静态资源 URL
 | 
			
		||||
	Redis         RedisConfig       // redis 连接信息
 | 
			
		||||
	ApiConfig     ChatPlusApiConfig // ChatPlus API authorization configs
 | 
			
		||||
	AesEncryptKey string
 | 
			
		||||
	SmsConfig     AliYunSmsConfig       // AliYun send message service config
 | 
			
		||||
	OSS           OSSConfig             // OSS config
 | 
			
		||||
	MjConfig      MidJourneyConfig      // mj 绘画配置
 | 
			
		||||
	WeChatBot     bool                  // 是否启用微信机器人
 | 
			
		||||
	SdConfig      StableDiffusionConfig // sd 绘画配置
 | 
			
		||||
	Path           string `toml:"-"`
 | 
			
		||||
	Listen         string
 | 
			
		||||
	Session        Session
 | 
			
		||||
	AdminSession   Session
 | 
			
		||||
	ProxyURL       string
 | 
			
		||||
	MysqlDns       string                  // mysql 连接地址
 | 
			
		||||
	StaticDir      string                  // 静态资源目录
 | 
			
		||||
	StaticUrl      string                  // 静态资源 URL
 | 
			
		||||
	Redis          RedisConfig             // redis 连接信息
 | 
			
		||||
	ApiConfig      ApiConfig               // ChatPlus API authorization configs
 | 
			
		||||
	SMS            SMSConfig               // send mobile message config
 | 
			
		||||
	OSS            OSSConfig               // OSS config
 | 
			
		||||
	MjProxyConfigs []MjProxyConfig         // MJ proxy config
 | 
			
		||||
	MjPlusConfigs  []MjPlusConfig          // MJ plus config
 | 
			
		||||
	WeChatBot      bool                    // 是否启用微信机器人
 | 
			
		||||
	SdConfigs      []StableDiffusionConfig // sd AI draw service pool
 | 
			
		||||
 | 
			
		||||
	XXLConfig    XXLConfig
 | 
			
		||||
	AlipayConfig AlipayConfig
 | 
			
		||||
	XXLConfig     XXLConfig
 | 
			
		||||
	AlipayConfig  AlipayConfig
 | 
			
		||||
	HuPiPayConfig HuPiPayConfig
 | 
			
		||||
	SmtpConfig    SmtpConfig // 邮件发送配置
 | 
			
		||||
	JPayConfig    JPayConfig // payjs 支付配置
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ChatPlusApiConfig struct {
 | 
			
		||||
type SmtpConfig struct {
 | 
			
		||||
	UseTls   bool // 是否使用 TLS 发送
 | 
			
		||||
	Host     string
 | 
			
		||||
	Port     int
 | 
			
		||||
	AppName  string // 应用名称
 | 
			
		||||
	From     string // 发件人邮箱地址
 | 
			
		||||
	Password string // 发件人邮箱密码
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ApiConfig struct {
 | 
			
		||||
	ApiURL string
 | 
			
		||||
	AppId  string
 | 
			
		||||
	Token  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type MidJourneyConfig struct {
 | 
			
		||||
	Enabled   bool
 | 
			
		||||
	UserToken string
 | 
			
		||||
	BotToken  string
 | 
			
		||||
	GuildId   string // Server ID
 | 
			
		||||
	ChanelId  string // Chanel ID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type WeChatConfig struct {
 | 
			
		||||
type MjProxyConfig struct {
 | 
			
		||||
	Enabled bool
 | 
			
		||||
	ApiURL  string // api 地址
 | 
			
		||||
	Mode    string // 绘画模式,可选值:fast/turbo/relax
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type StableDiffusionConfig struct {
 | 
			
		||||
	Enabled         bool
 | 
			
		||||
	ApiURL          string
 | 
			
		||||
	ApiKey          string
 | 
			
		||||
	Txt2ImgJsonPath string
 | 
			
		||||
	Enabled bool
 | 
			
		||||
	Model   string // 模型名称
 | 
			
		||||
	ApiURL  string
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AliYunSmsConfig struct {
 | 
			
		||||
	AccessKey    string
 | 
			
		||||
	AccessSecret string
 | 
			
		||||
	Product      string
 | 
			
		||||
	Domain       string
 | 
			
		||||
	Sign         string // 短信签名
 | 
			
		||||
	CodeTempId   string // 验证码短信模板 ID
 | 
			
		||||
type MjPlusConfig struct {
 | 
			
		||||
	Enabled bool   // 如果启用了 MidJourney Plus,将会自动禁用原生的MidJourney服务
 | 
			
		||||
	ApiURL  string // api 地址
 | 
			
		||||
	Mode    string // 绘画模式,可选值:fast/turbo/relax
 | 
			
		||||
	ApiKey  string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type AlipayConfig struct {
 | 
			
		||||
	Enabled         bool   // 是否启用该服务
 | 
			
		||||
	Enabled         bool   // 是否启用该支付通道
 | 
			
		||||
	SandBox         bool   // 是否沙盒环境
 | 
			
		||||
	AppId           string // 应用 ID
 | 
			
		||||
	UserId          string // 支付宝用户 ID
 | 
			
		||||
@@ -70,6 +82,28 @@ type AlipayConfig struct {
 | 
			
		||||
	AlipayPublicKey string // 支付宝公钥文件路径
 | 
			
		||||
	RootCert        string // Root 秘钥路径
 | 
			
		||||
	NotifyURL       string // 异步通知回调
 | 
			
		||||
	ReturnURL       string // 支付成功返回地址
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type HuPiPayConfig struct { //虎皮椒第四方支付配置
 | 
			
		||||
	Enabled   bool   // 是否启用该支付通道
 | 
			
		||||
	Name      string // 支付名称,如:wechat/alipay
 | 
			
		||||
	AppId     string // App ID
 | 
			
		||||
	AppSecret string // app 密钥
 | 
			
		||||
	ApiURL    string // 支付网关
 | 
			
		||||
	NotifyURL string // 异步通知回调
 | 
			
		||||
	ReturnURL string // 支付成功返回地址
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JPayConfig PayJs 支付配置
 | 
			
		||||
type JPayConfig struct {
 | 
			
		||||
	Enabled    bool
 | 
			
		||||
	Name       string // 支付名称,默认 wechat
 | 
			
		||||
	AppId      string // 商户 ID
 | 
			
		||||
	PrivateKey string // 私钥
 | 
			
		||||
	ApiURL     string // API 网关
 | 
			
		||||
	NotifyURL  string // 异步回调地址
 | 
			
		||||
	ReturnURL  string // 支付成功返回地址
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type XXLConfig struct { // XXL 任务调度配置
 | 
			
		||||
@@ -88,29 +122,21 @@ type RedisConfig struct {
 | 
			
		||||
	DB       int
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// LicenseKey 存储许可证书的 KEY
 | 
			
		||||
const LicenseKey = "Geek-AI-License"
 | 
			
		||||
 | 
			
		||||
type License struct {
 | 
			
		||||
	Key       string `json:"key"`        // 许可证书密钥
 | 
			
		||||
	MachineId string `json:"machine_id"` // 机器码
 | 
			
		||||
	UserNum   int    `json:"user_num"`   // 用户数量
 | 
			
		||||
	ExpiredAt int64  `json:"expired_at"` // 过期时间
 | 
			
		||||
	IsActive  bool   `json:"is_active"`  // 是否激活
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c RedisConfig) Url() string {
 | 
			
		||||
	return fmt.Sprintf("%s:%d", c.Host, c.Port)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Manager 管理员
 | 
			
		||||
type Manager struct {
 | 
			
		||||
	Username string `json:"username"`
 | 
			
		||||
	Password string `json:"password"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ChatConfig 系统默认的聊天配置
 | 
			
		||||
type ChatConfig struct {
 | 
			
		||||
	OpenAI  ModelAPIConfig `json:"open_ai"`
 | 
			
		||||
	Azure   ModelAPIConfig `json:"azure"`
 | 
			
		||||
	ChatGML ModelAPIConfig `json:"chat_gml"`
 | 
			
		||||
	Baidu   ModelAPIConfig `json:"baidu"`
 | 
			
		||||
	XunFei  ModelAPIConfig `json:"xun_fei"`
 | 
			
		||||
 | 
			
		||||
	EnableContext bool `json:"enable_context"` // 是否开启聊天上下文
 | 
			
		||||
	EnableHistory bool `json:"enable_history"` // 是否允许保存聊天记录
 | 
			
		||||
	ContextDeep   int  `json:"context_deep"`   // 上下文深度
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Platform string
 | 
			
		||||
 | 
			
		||||
const OpenAI = Platform("OpenAI")
 | 
			
		||||
@@ -118,33 +144,37 @@ const Azure = Platform("Azure")
 | 
			
		||||
const ChatGLM = Platform("ChatGLM")
 | 
			
		||||
const Baidu = Platform("Baidu")
 | 
			
		||||
const XunFei = Platform("XunFei")
 | 
			
		||||
 | 
			
		||||
// UserChatConfig 用户的聊天配置
 | 
			
		||||
type UserChatConfig struct {
 | 
			
		||||
	ApiKeys map[Platform]string `json:"api_keys"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ModelAPIConfig struct {
 | 
			
		||||
	ApiURL      string  `json:"api_url,omitempty"`
 | 
			
		||||
	Temperature float32 `json:"temperature"`
 | 
			
		||||
	MaxTokens   int     `json:"max_tokens"`
 | 
			
		||||
	ApiKey      string  `json:"api_key"`
 | 
			
		||||
}
 | 
			
		||||
const QWen = Platform("QWen")
 | 
			
		||||
 | 
			
		||||
type SystemConfig struct {
 | 
			
		||||
	Title           string   `json:"title"`
 | 
			
		||||
	AdminTitle      string   `json:"admin_title"`
 | 
			
		||||
	Models          []string `json:"models"`
 | 
			
		||||
	UserInitCalls   int      `json:"user_init_calls"` // 新用户注册默认总送多少次调用
 | 
			
		||||
	InitImgCalls    int      `json:"init_img_calls"`
 | 
			
		||||
	VipMonthCalls   int      `json:"vip_month_calls"` // 会员每个赠送的调用次数
 | 
			
		||||
	EnabledRegister bool     `json:"enabled_register"`
 | 
			
		||||
	EnabledMsg      bool     `json:"enabled_msg"`       // 启用短信验证码服务
 | 
			
		||||
	EnabledDraw     bool     `json:"enabled_draw"`      // 启动 AI 绘画功能
 | 
			
		||||
	RewardImg       string   `json:"reward_img"`        // 众筹收款二维码地址
 | 
			
		||||
	EnabledFunction bool     `json:"enabled_function"`  // 启用 API 函数功能
 | 
			
		||||
	EnabledReward   bool     `json:"enabled_reward"`    // 启用众筹功能
 | 
			
		||||
	EnabledAlipay   bool     `json:"enabled_alipay"`    // 是否启用支付宝支付通道
 | 
			
		||||
	OrderPayTimeout int      `json:"order_pay_timeout"` //订单支付超时时间
 | 
			
		||||
	DefaultModels   []string `json:"default_models"`    // 默认开通的 AI 模型
 | 
			
		||||
	Title         string `json:"title,omitempty"`
 | 
			
		||||
	AdminTitle    string `json:"admin_title,omitempty"`
 | 
			
		||||
	Logo          string `json:"logo,omitempty"`
 | 
			
		||||
	InitPower     int    `json:"init_power,omitempty"`      // 新用户注册赠送算力值
 | 
			
		||||
	DailyPower    int    `json:"daily_power,omitempty"`     // 每日赠送算力
 | 
			
		||||
	InvitePower   int    `json:"invite_power,omitempty"`    // 邀请新用户赠送算力值
 | 
			
		||||
	VipMonthPower int    `json:"vip_month_power,omitempty"` // VIP 会员每月赠送的算力值
 | 
			
		||||
 | 
			
		||||
	RegisterWays    []string `json:"register_ways,omitempty"`    // 注册方式:支持手机(mobile),邮箱注册(email),账号密码注册
 | 
			
		||||
	EnabledRegister bool     `json:"enabled_register,omitempty"` // 是否开放注册
 | 
			
		||||
 | 
			
		||||
	RewardImg     string  `json:"reward_img,omitempty"`     // 众筹收款二维码地址
 | 
			
		||||
	EnabledReward bool    `json:"enabled_reward,omitempty"` // 启用众筹功能
 | 
			
		||||
	PowerPrice    float64 `json:"power_price,omitempty"`    // 算力单价
 | 
			
		||||
 | 
			
		||||
	OrderPayTimeout int    `json:"order_pay_timeout,omitempty"` //订单支付超时时间
 | 
			
		||||
	VipInfoText     string `json:"vip_info_text,omitempty"`     // 会员页面充值说明
 | 
			
		||||
	DefaultModels   []int  `json:"default_models,omitempty"`    // 默认开通的 AI 模型
 | 
			
		||||
 | 
			
		||||
	MjPower       int `json:"mj_power,omitempty"`        // MJ 绘画消耗算力
 | 
			
		||||
	MjActionPower int `json:"mj_action_power,omitempty"` // MJ 操作(放大,变换)消耗算力
 | 
			
		||||
	SdPower       int `json:"sd_power,omitempty"`        // SD 绘画消耗算力
 | 
			
		||||
	DallPower     int `json:"dall_power,omitempty"`      // DALLE3 绘图消耗算力
 | 
			
		||||
 | 
			
		||||
	WechatCardURL string `json:"wechat_card_url,omitempty"` // 微信客服地址
 | 
			
		||||
 | 
			
		||||
	EnableContext bool `json:"enable_context,omitempty"`
 | 
			
		||||
	ContextDeep   int  `json:"context_deep,omitempty"`
 | 
			
		||||
 | 
			
		||||
	SdNegPrompt string `json:"sd_neg_prompt"` // SD 默认反向提示词
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,92 +1,28 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
type FunctionCall struct {
 | 
			
		||||
	Name      string `json:"name"`
 | 
			
		||||
	Arguments string `json:"arguments"`
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
type ToolCall struct {
 | 
			
		||||
	Type     string `json:"type"`
 | 
			
		||||
	Function struct {
 | 
			
		||||
		Name      string `json:"name"`
 | 
			
		||||
		Arguments string `json:"arguments"`
 | 
			
		||||
	} `json:"function"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Tool struct {
 | 
			
		||||
	Type     string   `json:"type"`
 | 
			
		||||
	Function Function `json:"function"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Function struct {
 | 
			
		||||
	Name        string     `json:"name"`
 | 
			
		||||
	Description string     `json:"description"`
 | 
			
		||||
	Parameters  Parameters `json:"parameters"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Parameters struct {
 | 
			
		||||
	Type       string              `json:"type"`
 | 
			
		||||
	Required   []string            `json:"required"`
 | 
			
		||||
	Properties map[string]Property `json:"properties"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Property struct {
 | 
			
		||||
	Type        string `json:"type"`
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	FuncZaoBao     = "zao_bao"     // 每日早报
 | 
			
		||||
	FuncHeadLine   = "headline"    // 今日头条
 | 
			
		||||
	FuncWeibo      = "weibo_hot"   // 微博热搜
 | 
			
		||||
	FuncMidJourney = "mid_journey" // MJ 绘画
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var InnerFunctions = []Function{
 | 
			
		||||
	{
 | 
			
		||||
		Name:        FuncZaoBao,
 | 
			
		||||
		Description: "每日早报,获取当天全球的热门新闻事件列表",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
 | 
			
		||||
			Type: "object",
 | 
			
		||||
			Properties: map[string]Property{
 | 
			
		||||
				"text": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			Required: []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
	{
 | 
			
		||||
		Name:        FuncWeibo,
 | 
			
		||||
		Description: "新浪微博热搜榜,微博当日热搜榜单",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
			Type: "object",
 | 
			
		||||
			Properties: map[string]Property{
 | 
			
		||||
				"text": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			Required: []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		Name:        FuncHeadLine,
 | 
			
		||||
		Description: "今日头条,给用户推荐当天的头条新闻,周榜热文",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
			Type: "object",
 | 
			
		||||
			Properties: map[string]Property{
 | 
			
		||||
				"text": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			Required: []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
 | 
			
		||||
	{
 | 
			
		||||
		Name:        FuncMidJourney,
 | 
			
		||||
		Description: "AI 绘画工具,使用 MJ MidJourney API 进行 AI 绘画",
 | 
			
		||||
		Parameters: Parameters{
 | 
			
		||||
			Type: "object",
 | 
			
		||||
			Properties: map[string]Property{
 | 
			
		||||
				"prompt": {
 | 
			
		||||
					Type:        "string",
 | 
			
		||||
					Description: "提示词,如果该参数中有中文的话,则需要翻译成英文。提示词中的参数作为提示的一部分,不要删除",
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			Required: []string{},
 | 
			
		||||
		},
 | 
			
		||||
	},
 | 
			
		||||
	Name        string                 `json:"name"`
 | 
			
		||||
	Description string                 `json:"description"`
 | 
			
		||||
	Parameters  map[string]interface{} `json:"parameters"`
 | 
			
		||||
	Required    interface{}            `json:"required,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,15 +1,22 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"context"
 | 
			
		||||
	"sync"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MKey interface {
 | 
			
		||||
	string | int
 | 
			
		||||
	string | int | uint
 | 
			
		||||
}
 | 
			
		||||
type MValue interface {
 | 
			
		||||
	*WsClient | *ChatSession | context.CancelFunc | []interface{}
 | 
			
		||||
	*WsClient | *ChatSession | context.CancelFunc | []Message
 | 
			
		||||
}
 | 
			
		||||
type LMap[K MKey, T MValue] struct {
 | 
			
		||||
	lock sync.RWMutex
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,12 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
type OrderStatus int
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
@@ -10,7 +17,7 @@ const (
 | 
			
		||||
 | 
			
		||||
type OrderRemark struct {
 | 
			
		||||
	Days     int     `json:"days"`  // 有效期
 | 
			
		||||
	Calls    int     `json:"calls"` // 增加调用次数
 | 
			
		||||
	Power    int     `json:"power"` // 增加算力点数
 | 
			
		||||
	Name     string  `json:"name"`  // 产品名称
 | 
			
		||||
	Price    float64 `json:"price"`
 | 
			
		||||
	Discount float64 `json:"discount"`
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,12 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
type OSSConfig struct {
 | 
			
		||||
	Active string
 | 
			
		||||
	Local  LocalStorageConfig
 | 
			
		||||
@@ -12,6 +19,7 @@ type MiniOssConfig struct {
 | 
			
		||||
	AccessKey    string
 | 
			
		||||
	AccessSecret string
 | 
			
		||||
	Bucket       string
 | 
			
		||||
	SubDir       string
 | 
			
		||||
	UseSSL       bool
 | 
			
		||||
	Domain       string
 | 
			
		||||
}
 | 
			
		||||
@@ -21,6 +29,7 @@ type QiNiuOssConfig struct {
 | 
			
		||||
	AccessKey    string
 | 
			
		||||
	AccessSecret string
 | 
			
		||||
	Bucket       string
 | 
			
		||||
	SubDir       string
 | 
			
		||||
	Domain       string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -29,6 +38,7 @@ type AliYunOssConfig struct {
 | 
			
		||||
	AccessKey    string
 | 
			
		||||
	AccessSecret string
 | 
			
		||||
	Bucket       string
 | 
			
		||||
	SubDir       string
 | 
			
		||||
	Domain       string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,11 +1,17 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
const LoginUserID = "LOGIN_USER_ID"
 | 
			
		||||
const LoginUserCache = "LOGIN_USER_CACHE"
 | 
			
		||||
 | 
			
		||||
const UserAuthHeader = "Authorization"
 | 
			
		||||
const AdminAuthHeader = "Admin-Authorization"
 | 
			
		||||
const ChatTokenHeader = "Chat-Token"
 | 
			
		||||
 | 
			
		||||
// Session configs struct
 | 
			
		||||
type Session struct {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										33
									
								
								api/core/types/sms.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										33
									
								
								api/core/types/sms.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,33 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
type SMSConfig struct {
 | 
			
		||||
	Active string
 | 
			
		||||
	Ali    SmsConfigAli
 | 
			
		||||
	Bao    SmsConfigBao
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SmsConfigAli 阿里云短信平台配置
 | 
			
		||||
type SmsConfigAli struct {
 | 
			
		||||
	AccessKey    string
 | 
			
		||||
	AccessSecret string
 | 
			
		||||
	Product      string
 | 
			
		||||
	Domain       string
 | 
			
		||||
	Sign         string // 短信签名
 | 
			
		||||
	CodeTempId   string // 验证码短信模板 ID
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SmsConfigBao 短信宝平台配置
 | 
			
		||||
type SmsConfigBao struct {
 | 
			
		||||
	Username     string //短信宝平台注册的用户名
 | 
			
		||||
	Password     string //短信宝平台注册的密码
 | 
			
		||||
	Domain       string //域名
 | 
			
		||||
	Sign         string // 短信签名
 | 
			
		||||
	CodeTemplate string // 验证码短信模板 匹配
 | 
			
		||||
}
 | 
			
		||||
@@ -1,5 +1,12 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
// TaskType 任务类别
 | 
			
		||||
type TaskType string
 | 
			
		||||
 | 
			
		||||
@@ -9,30 +16,24 @@ func (t TaskType) String() string {
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TaskImage     = TaskType("image")
 | 
			
		||||
	TaskBlend     = TaskType("blend")
 | 
			
		||||
	TaskSwapFace  = TaskType("swapFace")
 | 
			
		||||
	TaskUpscale   = TaskType("upscale")
 | 
			
		||||
	TaskVariation = TaskType("variation")
 | 
			
		||||
	TaskTxt2Img   = TaskType("text2img")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// TaskSrc 任务来源
 | 
			
		||||
type TaskSrc string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	TaskSrcChat = TaskSrc("chat") // 来自聊天页面
 | 
			
		||||
	TaskSrcImg  = TaskSrc("img")  // 专业绘画页面
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MjTask MidJourney 任务
 | 
			
		||||
type MjTask struct {
 | 
			
		||||
	Id          int      `json:"id"`
 | 
			
		||||
	Id          uint     `json:"id"`
 | 
			
		||||
	TaskId      string   `json:"task_id"`
 | 
			
		||||
	ImgArr      []string `json:"img_arr"`
 | 
			
		||||
	ChannelId   string   `json:"channel_id"`
 | 
			
		||||
	SessionId   string   `json:"session_id"`
 | 
			
		||||
	Src         TaskSrc  `json:"src"`
 | 
			
		||||
	Type        TaskType `json:"type"`
 | 
			
		||||
	UserId      int      `json:"user_id"`
 | 
			
		||||
	Prompt      string   `json:"prompt,omitempty"`
 | 
			
		||||
	ChatId      string   `json:"chat_id,omitempty"`
 | 
			
		||||
	RoleId      int      `json:"role_id,omitempty"`
 | 
			
		||||
	Icon        string   `json:"icon,omitempty"`
 | 
			
		||||
	NegPrompt   string   `json:"neg_prompt,omitempty"`
 | 
			
		||||
	Params      string   `json:"full_prompt"`
 | 
			
		||||
	Index       int      `json:"index,omitempty"`
 | 
			
		||||
	MessageId   string   `json:"message_id,omitempty"`
 | 
			
		||||
	MessageHash string   `json:"message_hash,omitempty"`
 | 
			
		||||
@@ -42,28 +43,40 @@ type MjTask struct {
 | 
			
		||||
type SdTask struct {
 | 
			
		||||
	Id         int          `json:"id"` // job 数据库ID
 | 
			
		||||
	SessionId  string       `json:"session_id"`
 | 
			
		||||
	Src        TaskSrc      `json:"src"`
 | 
			
		||||
	Type       TaskType     `json:"type"`
 | 
			
		||||
	UserId     int          `json:"user_id"`
 | 
			
		||||
	Prompt     string       `json:"prompt,omitempty"`
 | 
			
		||||
	Params     SdTaskParams `json:"params"`
 | 
			
		||||
	RetryCount int          `json:"retry_count"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type SdTaskParams struct {
 | 
			
		||||
	TaskId         string  `json:"task_id"`
 | 
			
		||||
	Prompt         string  `json:"prompt"`          // 提示词
 | 
			
		||||
	NegativePrompt string  `json:"negative_prompt"` // 反向提示词
 | 
			
		||||
	Steps          int     `json:"steps"`           // 迭代步数,默认20
 | 
			
		||||
	Sampler        string  `json:"sampler"`         // 采样器
 | 
			
		||||
	FaceFix        bool    `json:"face_fix"`        // 面部修复
 | 
			
		||||
	CfgScale       float32 `json:"cfg_scale"`       //引导系数,默认 7
 | 
			
		||||
	Seed           int64   `json:"seed"`            // 随机数种子
 | 
			
		||||
	Height         int     `json:"height"`
 | 
			
		||||
	Width          int     `json:"width"`
 | 
			
		||||
	HdFix          bool    `json:"hd_fix"`         // 启用高清修复
 | 
			
		||||
	HdRedrawRate   float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
 | 
			
		||||
	HdScale        int     `json:"hd_scale"`       // 放大倍数
 | 
			
		||||
	HdScaleAlg     string  `json:"hd_scale_alg"`   // 放大算法
 | 
			
		||||
	HdSteps        int     `json:"hd_steps"`       // 高清修复迭代步数
 | 
			
		||||
	TaskId       string  `json:"task_id"`
 | 
			
		||||
	Prompt       string  `json:"prompt"`     // 提示词
 | 
			
		||||
	NegPrompt    string  `json:"neg_prompt"` // 反向提示词
 | 
			
		||||
	Steps        int     `json:"steps"`      // 迭代步数,默认20
 | 
			
		||||
	Sampler      string  `json:"sampler"`    // 采样器
 | 
			
		||||
	Scheduler    string  `json:"scheduler"`
 | 
			
		||||
	FaceFix      bool    `json:"face_fix"`  // 面部修复
 | 
			
		||||
	CfgScale     float32 `json:"cfg_scale"` //引导系数,默认 7
 | 
			
		||||
	Seed         int64   `json:"seed"`      // 随机数种子
 | 
			
		||||
	Height       int     `json:"height"`
 | 
			
		||||
	Width        int     `json:"width"`
 | 
			
		||||
	HdFix        bool    `json:"hd_fix"`         // 启用高清修复
 | 
			
		||||
	HdRedrawRate float32 `json:"hd_redraw_rate"` // 高清修复重绘幅度
 | 
			
		||||
	HdScale      int     `json:"hd_scale"`       // 放大倍数
 | 
			
		||||
	HdScaleAlg   string  `json:"hd_scale_alg"`   // 放大算法
 | 
			
		||||
	HdSteps      int     `json:"hd_steps"`       // 高清修复迭代步数
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// DallTask DALL-E task
 | 
			
		||||
type DallTask struct {
 | 
			
		||||
	JobId   uint   `json:"job_id"`
 | 
			
		||||
	UserId  uint   `json:"user_id"`
 | 
			
		||||
	Prompt  string `json:"prompt"`
 | 
			
		||||
	N       int    `json:"n"`
 | 
			
		||||
	Quality string `json:"quality"`
 | 
			
		||||
	Size    string `json:"size"`
 | 
			
		||||
	Style   string `json:"style"`
 | 
			
		||||
 | 
			
		||||
	Power int `json:"power"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,12 @@
 | 
			
		||||
package types
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
// BizVo 业务返回 VO
 | 
			
		||||
type BizVo struct {
 | 
			
		||||
	Code     BizCode     `json:"code"`
 | 
			
		||||
@@ -21,7 +28,7 @@ const (
 | 
			
		||||
	WsStart  = WsMsgType("start")
 | 
			
		||||
	WsMiddle = WsMsgType("middle")
 | 
			
		||||
	WsEnd    = WsMsgType("end")
 | 
			
		||||
	WsMjImg  = WsMsgType("mj")
 | 
			
		||||
	WsErr    = WsMsgType("error")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type BizCode int
 | 
			
		||||
@@ -30,6 +37,7 @@ const (
 | 
			
		||||
	Success       = BizCode(0)
 | 
			
		||||
	Failed        = BizCode(1)
 | 
			
		||||
	NotAuthorized = BizCode(400) // 未授权
 | 
			
		||||
	NotPermission = BizCode(403) // 没有权限
 | 
			
		||||
 | 
			
		||||
	OkMsg       = "Success"
 | 
			
		||||
	ErrorMsg    = "系统开小差了"
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										24
									
								
								api/go.mod
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								api/go.mod
									
									
									
									
									
								
							@@ -1,4 +1,4 @@
 | 
			
		||||
module chatplus
 | 
			
		||||
module geekai
 | 
			
		||||
 | 
			
		||||
go 1.19
 | 
			
		||||
 | 
			
		||||
@@ -6,7 +6,6 @@ require (
 | 
			
		||||
	github.com/BurntSushi/toml v1.1.0
 | 
			
		||||
	github.com/aliyun/alibaba-cloud-sdk-go v1.62.405
 | 
			
		||||
	github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible
 | 
			
		||||
	github.com/bwmarrin/discordgo v0.27.1
 | 
			
		||||
	github.com/eatmoreapple/openwechat v1.2.1
 | 
			
		||||
	github.com/gin-gonic/gin v1.9.1
 | 
			
		||||
	github.com/go-redis/redis/v8 v8.11.5
 | 
			
		||||
@@ -19,7 +18,6 @@ require (
 | 
			
		||||
	github.com/qiniu/go-sdk/v7 v7.17.1
 | 
			
		||||
	github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e
 | 
			
		||||
	github.com/smartwalle/alipay/v3 v3.2.15
 | 
			
		||||
	github.com/syndtr/goleveldb v1.0.0
 | 
			
		||||
	go.uber.org/zap v1.23.0
 | 
			
		||||
	gopkg.in/natefinch/lumberjack.v2 v2.2.1
 | 
			
		||||
	gorm.io/driver/mysql v1.4.7
 | 
			
		||||
@@ -27,6 +25,23 @@ require (
 | 
			
		||||
 | 
			
		||||
require github.com/xxl-job/xxl-job-executor-go v1.2.0
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/mojocn/base64Captcha v1.3.1
 | 
			
		||||
	github.com/shirou/gopsutil v3.21.11+incompatible
 | 
			
		||||
	github.com/shopspring/decimal v1.3.1
 | 
			
		||||
	github.com/syndtr/goleveldb v1.0.0
 | 
			
		||||
	golang.org/x/image v0.0.0-20211028202545-6944b10bf410
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/go-ole/go-ole v1.2.6 // indirect
 | 
			
		||||
	github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 // indirect
 | 
			
		||||
	github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db // indirect
 | 
			
		||||
	github.com/tklauser/go-sysconf v0.3.13 // indirect
 | 
			
		||||
	github.com/tklauser/numcpus v0.7.0 // indirect
 | 
			
		||||
	github.com/yusufpapurcu/wmi v1.2.4 // indirect
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
require (
 | 
			
		||||
	github.com/andybalholm/brotli v1.0.4 // indirect
 | 
			
		||||
	github.com/bytedance/sonic v1.9.1 // indirect
 | 
			
		||||
@@ -87,7 +102,6 @@ require (
 | 
			
		||||
	github.com/go-playground/locales v0.14.1 // indirect
 | 
			
		||||
	github.com/go-playground/universal-translator v0.18.1 // indirect
 | 
			
		||||
	github.com/go-playground/validator/v10 v10.14.0 // indirect
 | 
			
		||||
	github.com/golang/snappy v0.0.1 // indirect
 | 
			
		||||
	github.com/json-iterator/go v1.1.12 // indirect
 | 
			
		||||
	github.com/leodido/go-urn v1.2.4 // indirect
 | 
			
		||||
	github.com/mattn/go-isatty v0.0.19 // indirect
 | 
			
		||||
@@ -98,6 +112,6 @@ require (
 | 
			
		||||
	go.uber.org/fx v1.19.3
 | 
			
		||||
	go.uber.org/multierr v1.6.0 // indirect
 | 
			
		||||
	golang.org/x/crypto v0.12.0
 | 
			
		||||
	golang.org/x/sys v0.11.0 // indirect
 | 
			
		||||
	golang.org/x/sys v0.15.0 // indirect
 | 
			
		||||
	gorm.io/gorm v1.25.1
 | 
			
		||||
)
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										31
									
								
								api/go.sum
									
									
									
									
									
								
							
							
						
						
									
										31
									
								
								api/go.sum
									
									
									
									
									
								
							@@ -7,8 +7,6 @@ github.com/aliyun/aliyun-oss-go-sdk v2.2.9+incompatible/go.mod h1:T/Aws4fEfogEE9
 | 
			
		||||
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
 | 
			
		||||
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
 | 
			
		||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
 | 
			
		||||
github.com/bwmarrin/discordgo v0.27.1 h1:ib9AIc/dom1E/fSIulrBwnez0CToJE113ZGt4HoliGY=
 | 
			
		||||
github.com/bwmarrin/discordgo v0.27.1/go.mod h1:NJZpH+1AfhIcyQsPeuBKsUtYrRnjkyu0kIVMCHkZtRY=
 | 
			
		||||
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s=
 | 
			
		||||
github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U=
 | 
			
		||||
@@ -42,6 +40,8 @@ github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SU
 | 
			
		||||
github.com/go-basic/ipv4 v1.0.0 h1:gjyFAa1USC1hhXTkPOwBWDPfMcUaIM+tvo1XzV9EZxs=
 | 
			
		||||
github.com/go-basic/ipv4 v1.0.0/go.mod h1:etLBnaxbidQfuqE6wgZQfs38nEWNmzALkxDZe4xY8Dg=
 | 
			
		||||
github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
 | 
			
		||||
github.com/go-ole/go-ole v1.2.6 h1:/Fpf6oFPoeFik9ty7siob0G6Ke8QvQEuVcuChpwXzpY=
 | 
			
		||||
github.com/go-ole/go-ole v1.2.6/go.mod h1:pprOEPIfldk/42T2oK7lQ4v4JSDwmV0As9GaiUsvbm0=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4=
 | 
			
		||||
github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s=
 | 
			
		||||
github.com/go-playground/locales v0.13.0/go.mod h1:taPMhCMXrRLJO55olJkUXHZBHCxTMfnGwq/HNwmWNS8=
 | 
			
		||||
@@ -66,14 +66,15 @@ github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MG
 | 
			
		||||
github.com/goji/httpauth v0.0.0-20160601135302-2da839ab0f4d/go.mod h1:nnjvkQ9ptGaCkuDUx6wNykzzlUixGxvkme+H/lnzb+A=
 | 
			
		||||
github.com/golang-jwt/jwt/v5 v5.0.0 h1:1n1XNM9hk7O9mnQoNBGolZvzebBQ7p93ULHRc28XJUE=
 | 
			
		||||
github.com/golang-jwt/jwt/v5 v5.0.0/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk=
 | 
			
		||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0 h1:DACJavvAHhabrF08vX0COfcOBJRhZ8lUbR+ZWIs0Y5g=
 | 
			
		||||
github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k=
 | 
			
		||||
github.com/golang/mock v1.6.0 h1:ErTB+efbowRARo13NNdxyJji2egdxLGQhRaY+DUumQc=
 | 
			
		||||
github.com/golang/mock v1.6.0/go.mod h1:p6yTPP+5HYm5mzsMV8JkE6ZKdX+/wYM6Hr+LicevLPs=
 | 
			
		||||
github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
 | 
			
		||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
 | 
			
		||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
 | 
			
		||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db h1:woRePGFeVFfLKN/pOkfl+p/TAqKOfFu+7KPlMVpok/w=
 | 
			
		||||
github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 | 
			
		||||
github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4=
 | 
			
		||||
github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 | 
			
		||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
 | 
			
		||||
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
 | 
			
		||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
 | 
			
		||||
@@ -81,7 +82,6 @@ github.com/google/pprof v0.0.0-20230602150820-91b7bce49751 h1:hR7/MlvK23p6+lIw9S
 | 
			
		||||
github.com/google/pprof v0.0.0-20230602150820-91b7bce49751/go.mod h1:Jh3hGz2jkYak8qXPD19ryItVnUgpgeqzdkY/D0EaeuA=
 | 
			
		||||
github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I=
 | 
			
		||||
github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
 | 
			
		||||
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
 | 
			
		||||
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
 | 
			
		||||
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
 | 
			
		||||
@@ -135,6 +135,8 @@ github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJ
 | 
			
		||||
github.com/modern-go/reflect2 v1.0.1/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
 | 
			
		||||
github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M=
 | 
			
		||||
github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk=
 | 
			
		||||
github.com/mojocn/base64Captcha v1.3.1 h1:2Wbkt8Oc8qjmNJ5GyOfSo4tgVQPsbKMftqASnq8GlT0=
 | 
			
		||||
github.com/mojocn/base64Captcha v1.3.1/go.mod h1:wAQCKEc5bDujxKRmbT6/vTnTt5CjStQ8bRfPWUuz/iY=
 | 
			
		||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646 h1:zYyBkD/k9seD2A7fsi6Oo2LfFZAehjjQMERAvZLEDnQ=
 | 
			
		||||
github.com/nfnt/resize v0.0.0-20180221191011-83c6a9932646/go.mod h1:jpp1/29i3P1S/RLdc7JQKbRpFeM1dOBd8T9ki5s+AY8=
 | 
			
		||||
github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE=
 | 
			
		||||
@@ -175,6 +177,10 @@ github.com/rogpeppe/go-internal v1.8.0 h1:FCbCCtXNOY3UtUuHUYaghJg4y7Fd14rXifAYUA
 | 
			
		||||
github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE=
 | 
			
		||||
github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc=
 | 
			
		||||
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
 | 
			
		||||
github.com/shirou/gopsutil v3.21.11+incompatible h1:+1+c1VGhc88SSonWP6foOcLhvnKlUeu/erjjvaPEYiI=
 | 
			
		||||
github.com/shirou/gopsutil v3.21.11+incompatible/go.mod h1:5b4v6he4MtMOwMlS0TUMTu2PcXUg8+E1lC7eC3UO/RA=
 | 
			
		||||
github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8=
 | 
			
		||||
github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
 | 
			
		||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
 | 
			
		||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 | 
			
		||||
github.com/skip2/go-qrcode v0.0.0-20200617195104-da1b6568686e h1:MRM5ITcdelLK2j1vwZ3Je0FKVCfqOLp5zO6trqMLYs0=
 | 
			
		||||
@@ -201,6 +207,10 @@ github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gt
 | 
			
		||||
github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
 | 
			
		||||
github.com/syndtr/goleveldb v1.0.0 h1:fBdIW9lB4Iz0n9khmH8w27SJ3QEJ7+IgjPEwGSZiFdE=
 | 
			
		||||
github.com/syndtr/goleveldb v1.0.0/go.mod h1:ZVVdQEZoIme9iO1Ch2Jdy24qqXrMMOU6lpPAyBWyWuQ=
 | 
			
		||||
github.com/tklauser/go-sysconf v0.3.13 h1:GBUpcahXSpR2xN01jhkNAbTLRk2Yzgggk8IM08lq3r4=
 | 
			
		||||
github.com/tklauser/go-sysconf v0.3.13/go.mod h1:zwleP4Q4OehZHGn4CYZDipCgg9usW5IJePewFCGVEa0=
 | 
			
		||||
github.com/tklauser/numcpus v0.7.0 h1:yjuerZP127QG9m5Zh/mSO4wqurYil27tHrqwRoRjpr4=
 | 
			
		||||
github.com/tklauser/numcpus v0.7.0/go.mod h1:bb6dMVcj8A42tSE7i32fsIUCbQNllK5iDguyOZRUzAY=
 | 
			
		||||
github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI=
 | 
			
		||||
github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08=
 | 
			
		||||
github.com/uber/jaeger-client-go v2.30.0+incompatible h1:D6wyKGCecFaSRUpo8lCVbaOOb6ThwMmTEbhRwtKR97o=
 | 
			
		||||
@@ -213,6 +223,8 @@ github.com/xxl-job/xxl-job-executor-go v1.2.0 h1:MTl2DpwrK2+hNjRRks2k7vB3oy+3onq
 | 
			
		||||
github.com/xxl-job/xxl-job-executor-go v1.2.0/go.mod h1:bUFhz/5Irp9zkdYk5MxhQcDDT6LlZrI8+rv5mHtQ1mo=
 | 
			
		||||
github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k=
 | 
			
		||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
 | 
			
		||||
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=
 | 
			
		||||
github.com/yusufpapurcu/wmi v1.2.4/go.mod h1:SBZ9tNy3G9/m5Oi98Zks0QjeHVDvuK0qfxQmPyzfmi0=
 | 
			
		||||
go.uber.org/atomic v1.7.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
 | 
			
		||||
go.uber.org/atomic v1.9.0 h1:ECmE8Bn/WFTYwEW/bpKD3M8VtR/zQVbavAoalC1PYyE=
 | 
			
		||||
go.uber.org/atomic v1.9.0/go.mod h1:fEN4uk6kAWBTFdckzkM89CLk9XfWZrxpCo0nPH17wJc=
 | 
			
		||||
@@ -230,7 +242,6 @@ golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k=
 | 
			
		||||
golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210421170649-83a5a9bb288b/go.mod h1:T9bdIzuCu7OtxOm1hfPfRQxPLYneinmdGuTeoZ9dtd4=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
 | 
			
		||||
golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
 | 
			
		||||
@@ -238,6 +249,9 @@ golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
 | 
			
		||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
 | 
			
		||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1 h1:k/i9J1pBpvlfR+9QsetwPyERsqu1GIbi967PQMq3Ivc=
 | 
			
		||||
golang.org/x/exp v0.0.0-20230522175609-2e198f4a06a1/go.mod h1:V1LtkGg67GoY2N1AnLN78QLrzxkLyJw7RJb1gzOOz9w=
 | 
			
		||||
golang.org/x/image v0.0.0-20190501045829-6d32002ffd75/go.mod h1:kZ7UVZpmo3dzQBMxlp+ypCbDeSB+sBbTgSJuh5dn5js=
 | 
			
		||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410 h1:hTftEOvwiOq2+O8k2D5/Q7COC7k5Qcrgc2TFURJYnvQ=
 | 
			
		||||
golang.org/x/image v0.0.0-20211028202545-6944b10bf410/go.mod h1:023OzeP/+EPmXeapQh35lcL3II3LrY8Ic+EFFKVhULM=
 | 
			
		||||
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
 | 
			
		||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
 | 
			
		||||
golang.org/x/mod v0.11.0 h1:bUO06HqtnRcc/7l71XBe4WcqTZ+3AH1J59zWDDwLKgU=
 | 
			
		||||
@@ -260,6 +274,7 @@ golang.org/x/sync v0.3.0/go.mod h1:FU7BRWz2tNW+3quACPkgCx/L+uEAv1htQ0V83Z9Rj+Y=
 | 
			
		||||
golang.org/x/sys v0.0.0-20180909124046-d0be0721c37e/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 | 
			
		||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
 | 
			
		||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20190916202348-b4ddaad3f8a3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
 | 
			
		||||
golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
@@ -271,8 +286,8 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
 | 
			
		||||
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.11.0 h1:eG7RXZHdqOJ1i+0lgLgCpSXAp6M3LYlAo6osgSi0xOM=
 | 
			
		||||
golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
 | 
			
		||||
golang.org/x/sys v0.15.0 h1:h48lPFYpsTvQJZF4EKyI4aLHaev3CxivZmv7yZig9pc=
 | 
			
		||||
golang.org/x/sys v0.15.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
 | 
			
		||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
 | 
			
		||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 | 
			
		||||
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
 | 
			
		||||
 
 | 
			
		||||
@@ -1,17 +1,26 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"context"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"github.com/mojocn/base64Captcha"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
@@ -20,47 +29,88 @@ import (
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
// Manager 管理员
 | 
			
		||||
type Manager struct {
 | 
			
		||||
	Username  string `json:"username"`
 | 
			
		||||
	Password  string `json:"password"`
 | 
			
		||||
	Captcha   string `json:"captcha"`    // 验证码
 | 
			
		||||
	CaptchaId string `json:"captcha_id"` // 验证码id
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const SuperManagerID = 1
 | 
			
		||||
 | 
			
		||||
type ManagerHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db    *gorm.DB
 | 
			
		||||
	redis *redis.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAdminHandler(app *core.AppServer, db *gorm.DB, client *redis.Client) *ManagerHandler {
 | 
			
		||||
	h := ManagerHandler{db: db, redis: client}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &ManagerHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, redis: client}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Login 登录
 | 
			
		||||
func (h *ManagerHandler) Login(c *gin.Context) {
 | 
			
		||||
	var data types.Manager
 | 
			
		||||
	var data Manager
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	manager := h.App.Config.Manager
 | 
			
		||||
	if data.Username == manager.Username && data.Password == manager.Password {
 | 
			
		||||
		// 创建 token
 | 
			
		||||
		token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
			"user_id": manager.Username,
 | 
			
		||||
			"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
 | 
			
		||||
		})
 | 
			
		||||
		tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "Failed to generate token, "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// 保存到 redis
 | 
			
		||||
		key := "users/" + manager.Username
 | 
			
		||||
		if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
 | 
			
		||||
			resp.ERROR(c, "error with save token: "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		resp.SUCCESS(c, tokenString)
 | 
			
		||||
	} else {
 | 
			
		||||
		resp.ERROR(c, "用户名或者密码错误")
 | 
			
		||||
 | 
			
		||||
	// add captcha
 | 
			
		||||
	if !base64Captcha.DefaultMemStore.Verify(data.CaptchaId, data.Captcha, true) {
 | 
			
		||||
		resp.ERROR(c, "验证码错误!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var manager model.AdminUser
 | 
			
		||||
	res := h.DB.Model(&model.AdminUser{}).Where("username = ?", data.Username).First(&manager)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "请检查用户名或者密码是否填写正确")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	password := utils.GenPassword(data.Password, manager.Salt)
 | 
			
		||||
	if password != manager.Password {
 | 
			
		||||
		resp.ERROR(c, "用户名或密码错误")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 超级管理员默认是ID:1
 | 
			
		||||
	if manager.Id != SuperManagerID && manager.Status == false {
 | 
			
		||||
		resp.ERROR(c, "该用户已被禁止登录,请联系超级管理员")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 创建 token
 | 
			
		||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
		"user_id": manager.Id,
 | 
			
		||||
		"expired": time.Now().Add(time.Second * time.Duration(h.App.Config.Session.MaxAge)).Unix(),
 | 
			
		||||
	})
 | 
			
		||||
	tokenString, err := token.SignedString([]byte(h.App.Config.AdminSession.SecretKey))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "Failed to generate token, "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 保存到 redis
 | 
			
		||||
	key := fmt.Sprintf("admin/%d", manager.Id)
 | 
			
		||||
	if _, err := h.redis.Set(context.Background(), key, tokenString, 0).Result(); err != nil {
 | 
			
		||||
		resp.ERROR(c, "error with save token: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新最后登录时间和IP
 | 
			
		||||
	manager.LastLoginIp = c.ClientIP()
 | 
			
		||||
	manager.LastLoginAt = time.Now().Unix()
 | 
			
		||||
	h.DB.Updates(&manager)
 | 
			
		||||
 | 
			
		||||
	var result = struct {
 | 
			
		||||
		IsSuperAdmin bool   `json:"is_super_admin"`
 | 
			
		||||
		Token        string `json:"token"`
 | 
			
		||||
	}{
 | 
			
		||||
		IsSuperAdmin: manager.Id == 1,
 | 
			
		||||
		Token:        tokenString,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, result)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Logout 注销
 | 
			
		||||
@@ -75,74 +125,155 @@ func (h *ManagerHandler) Logout(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// Session 会话检测
 | 
			
		||||
func (h *ManagerHandler) Session(c *gin.Context) {
 | 
			
		||||
	token := c.GetHeader(types.AdminAuthHeader)
 | 
			
		||||
	if token == "" {
 | 
			
		||||
	id := h.GetLoginUserId(c)
 | 
			
		||||
	key := fmt.Sprintf("admin/%d", id)
 | 
			
		||||
	if _, err := h.redis.Get(context.Background(), key).Result(); err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
	} else {
 | 
			
		||||
		resp.SUCCESS(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Migrate 数据修正
 | 
			
		||||
func (h *ManagerHandler) Migrate(c *gin.Context) {
 | 
			
		||||
	opt := c.Query("opt")
 | 
			
		||||
	switch opt {
 | 
			
		||||
	case "user":
 | 
			
		||||
		// 将用户订阅角色的数据结构从 map 改成数组
 | 
			
		||||
		var users []model.User
 | 
			
		||||
		h.db.Find(&users)
 | 
			
		||||
		for _, u := range users {
 | 
			
		||||
			var m map[string]int
 | 
			
		||||
			var roleKeys = make([]string, 0)
 | 
			
		||||
			err := utils.JsonDecode(u.ChatRoles, &m)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for k := range m {
 | 
			
		||||
				roleKeys = append(roleKeys, k)
 | 
			
		||||
			}
 | 
			
		||||
			u.ChatRoles = utils.JsonEncode(roleKeys)
 | 
			
		||||
			h.db.Updates(&u)
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
	case "role":
 | 
			
		||||
		// 修改角色图片,改成绝对路径
 | 
			
		||||
		var roles []model.ChatRole
 | 
			
		||||
		h.db.Find(&roles)
 | 
			
		||||
		for _, r := range roles {
 | 
			
		||||
			if !strings.HasPrefix(r.Icon, "/") {
 | 
			
		||||
				r.Icon = "/" + r.Icon
 | 
			
		||||
				h.db.Updates(&r)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
	case "history":
 | 
			
		||||
		// 修改角色图片,改成绝对路径
 | 
			
		||||
		var message []model.HistoryMessage
 | 
			
		||||
		h.db.Find(&message)
 | 
			
		||||
		for _, r := range message {
 | 
			
		||||
			if !strings.HasPrefix(r.Icon, "/") {
 | 
			
		||||
				r.Icon = "/" + r.Icon
 | 
			
		||||
				h.db.Updates(&r)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
 | 
			
		||||
	case "avatar":
 | 
			
		||||
		// 更新用户的头像地址
 | 
			
		||||
		var users []model.User
 | 
			
		||||
		h.db.Find(&users)
 | 
			
		||||
		for _, u := range users {
 | 
			
		||||
			if !strings.HasPrefix(u.Avatar, "/") {
 | 
			
		||||
				u.Avatar = "/" + u.Avatar
 | 
			
		||||
				h.db.Updates(&u)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
	var manager model.AdminUser
 | 
			
		||||
	res := h.DB.Where("id", id).First(&manager)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, "SUCCESS")
 | 
			
		||||
	resp.SUCCESS(c, manager)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List 数据列表
 | 
			
		||||
func (h *ManagerHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.AdminUser
 | 
			
		||||
	res := h.DB.Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	users := make([]vo.AdminUser, 0)
 | 
			
		||||
	for _, item := range items {
 | 
			
		||||
		var u vo.AdminUser
 | 
			
		||||
		err := utils.CopyObject(item, &u)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		u.Id = item.Id
 | 
			
		||||
		u.CreatedAt = item.CreatedAt.Unix()
 | 
			
		||||
		users = append(users, u)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, users)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ManagerHandler) Save(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Username string `json:"username"`
 | 
			
		||||
		Password string `json:"password"`
 | 
			
		||||
		Status   bool   `json:"status"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.AdminUser
 | 
			
		||||
	res := h.DB.Where("username", data.Username).First(&user)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		resp.ERROR(c, "用户名已存在")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 生成密码
 | 
			
		||||
	salt := utils.RandString(8)
 | 
			
		||||
	password := utils.GenPassword(data.Password, salt)
 | 
			
		||||
	res = h.DB.Save(&model.AdminUser{
 | 
			
		||||
		Username: data.Username,
 | 
			
		||||
		Password: password,
 | 
			
		||||
		Salt:     salt,
 | 
			
		||||
		Status:   data.Status,
 | 
			
		||||
	})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "failed with update database")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove 删除管理员
 | 
			
		||||
func (h *ManagerHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	if id <= 0 {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if id == SuperManagerID {
 | 
			
		||||
		resp.ERROR(c, "超级管理员不能删除")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Where("id", id).Delete(&model.AdminUser{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Enable 启用/禁用
 | 
			
		||||
func (h *ManagerHandler) Enable(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id      uint `json:"id"`
 | 
			
		||||
		Enabled bool `json:"enabled"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.AdminUser{}).Where("id", data.Id).UpdateColumn("status", data.Enabled)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ResetPass 重置密码
 | 
			
		||||
func (h *ManagerHandler) ResetPass(c *gin.Context) {
 | 
			
		||||
	id := h.GetLoginUserId(c)
 | 
			
		||||
	if id != SuperManagerID {
 | 
			
		||||
		resp.ERROR(c, "只有超级管理员能够进行该操作")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id       int    `json:"id"`
 | 
			
		||||
		Password string `json:"password"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.AdminUser
 | 
			
		||||
	res := h.DB.Where("id", data.Id).First(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	password := utils.GenPassword(data.Password, user.Salt)
 | 
			
		||||
	user.Password = password
 | 
			
		||||
	res = h.DB.Updates(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,33 +1,43 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ApiKeyHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewApiKeyHandler(app *core.AppServer, db *gorm.DB) *ApiKeyHandler {
 | 
			
		||||
	h := ApiKeyHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &ApiKeyHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ApiKeyHandler) Save(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id       uint   `json:"id"`
 | 
			
		||||
		Platform string `json:"platform"`
 | 
			
		||||
		Name     string `json:"name"`
 | 
			
		||||
		Type     string `json:"type"`
 | 
			
		||||
		Value    string `json:"value"`
 | 
			
		||||
		ApiURL   string `json:"api_url"`
 | 
			
		||||
		Enabled  bool   `json:"enabled"`
 | 
			
		||||
		ProxyURL string `json:"proxy_url"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
@@ -36,11 +46,16 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	apiKey := model.ApiKey{}
 | 
			
		||||
	if data.Id > 0 {
 | 
			
		||||
		h.db.Find(&apiKey, data.Id)
 | 
			
		||||
		h.DB.Find(&apiKey, data.Id)
 | 
			
		||||
	}
 | 
			
		||||
	apiKey.Platform = data.Platform
 | 
			
		||||
	apiKey.Value = data.Value
 | 
			
		||||
	res := h.db.Debug().Save(&apiKey)
 | 
			
		||||
	apiKey.Type = data.Type
 | 
			
		||||
	apiKey.ApiURL = data.ApiURL
 | 
			
		||||
	apiKey.Enabled = data.Enabled
 | 
			
		||||
	apiKey.ProxyURL = data.ProxyURL
 | 
			
		||||
	apiKey.Name = data.Name
 | 
			
		||||
	res := h.DB.Save(&apiKey)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
@@ -58,9 +73,20 @@ func (h *ApiKeyHandler) Save(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ApiKeyHandler) List(c *gin.Context) {
 | 
			
		||||
	status := h.GetBool(c, "status")
 | 
			
		||||
	t := h.GetTrim(c, "type")
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if status {
 | 
			
		||||
		session = session.Where("enabled", true)
 | 
			
		||||
	}
 | 
			
		||||
	if t != "" {
 | 
			
		||||
		session = session.Where("type", t)
 | 
			
		||||
	}
 | 
			
		||||
	
 | 
			
		||||
	var items []model.ApiKey
 | 
			
		||||
	var keys = make([]vo.ApiKey, 0)
 | 
			
		||||
	res := h.db.Find(&items)
 | 
			
		||||
	res := session.Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var key vo.ApiKey
 | 
			
		||||
@@ -78,15 +104,37 @@ func (h *ApiKeyHandler) List(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, keys)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ApiKeyHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
func (h *ApiKeyHandler) Set(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id    uint        `json:"id"`
 | 
			
		||||
		Filed string      `json:"filed"`
 | 
			
		||||
		Value interface{} `json:"value"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		res := h.db.Where("id = ?", id).Delete(&model.ApiKey{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.ApiKey{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ApiKeyHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	if id <= 0 {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Where("id", id).Delete(&model.ApiKey{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										46
									
								
								api/handler/admin/captcha_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								api/handler/admin/captcha_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/mojocn/base64Captcha"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type CaptchaHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewCaptchaHandler(app *core.AppServer) *CaptchaHandler {
 | 
			
		||||
	return &CaptchaHandler{BaseHandler: handler.BaseHandler{App: app}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CaptchaVo struct {
 | 
			
		||||
	CaptchaId string `json:"captcha_id"`
 | 
			
		||||
	PicPath   string `json:"pic_path"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetCaptcha 获取验证码
 | 
			
		||||
func (h *CaptchaHandler) GetCaptcha(c *gin.Context) {
 | 
			
		||||
	var captchaVo CaptchaVo
 | 
			
		||||
	driver := base64Captcha.NewDriverDigit(48, 130, 4, 0.4, 10)
 | 
			
		||||
	cp := base64Captcha.NewCaptcha(driver, base64Captcha.DefaultMemStore)
 | 
			
		||||
	// b64s是图片的base64编码
 | 
			
		||||
	id, b64s, err := cp.Generate()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "生成验证码错误!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	captchaVo.CaptchaId = id
 | 
			
		||||
	captchaVo.PicPath = b64s
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, captchaVo)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										268
									
								
								api/handler/admin/chat_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										268
									
								
								api/handler/admin/chat_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,268 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ChatHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChatHandler(app *core.AppServer, db *gorm.DB) *ChatHandler {
 | 
			
		||||
	return &ChatHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type chatItemVo struct {
 | 
			
		||||
	Username  string      `json:"username"`
 | 
			
		||||
	UserId    uint        `json:"user_id"`
 | 
			
		||||
	ChatId    string      `json:"chat_id"`
 | 
			
		||||
	Title     string      `json:"title"`
 | 
			
		||||
	Role      vo.ChatRole `json:"role"`
 | 
			
		||||
	Model     string      `json:"model"`
 | 
			
		||||
	Token     int         `json:"token"`
 | 
			
		||||
	CreatedAt int64       `json:"created_at"`
 | 
			
		||||
	MsgNum    int         `json:"msg_num"` // 消息数量
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) List(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Title    string   `json:"title"`
 | 
			
		||||
		UserId   uint     `json:"user_id"`
 | 
			
		||||
		Model    string   `json:"model"`
 | 
			
		||||
		CreateAt []string `json:"created_time"`
 | 
			
		||||
		Page     int      `json:"page"`
 | 
			
		||||
		PageSize int      `json:"page_size"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if data.Title != "" {
 | 
			
		||||
		session = session.Where("title LIKE ?", "%"+data.Title+"%")
 | 
			
		||||
	}
 | 
			
		||||
	if data.UserId > 0 {
 | 
			
		||||
		session = session.Where("user_id = ?", data.UserId)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Model != "" {
 | 
			
		||||
		session = session.Where("model = ?", data.Model)
 | 
			
		||||
	}
 | 
			
		||||
	if len(data.CreateAt) == 2 {
 | 
			
		||||
		start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
 | 
			
		||||
		end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
 | 
			
		||||
		session = session.Where("created_at >= ? AND created_at <= ?", start, end)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.ChatItem{}).Count(&total)
 | 
			
		||||
	var items []model.ChatItem
 | 
			
		||||
	var list = make([]chatItemVo, 0)
 | 
			
		||||
	offset := (data.Page - 1) * data.PageSize
 | 
			
		||||
	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		userIds := make([]uint, 0)
 | 
			
		||||
		chatIds := make([]string, 0)
 | 
			
		||||
		roleIds := make([]uint, 0)
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			userIds = append(userIds, item.UserId)
 | 
			
		||||
			chatIds = append(chatIds, item.ChatId)
 | 
			
		||||
			roleIds = append(roleIds, item.RoleId)
 | 
			
		||||
		}
 | 
			
		||||
		var messages []model.ChatMessage
 | 
			
		||||
		var users []model.User
 | 
			
		||||
		var roles []model.ChatRole
 | 
			
		||||
		h.DB.Where("chat_id IN ?", chatIds).Find(&messages)
 | 
			
		||||
		h.DB.Where("id IN ?", userIds).Find(&users)
 | 
			
		||||
		h.DB.Where("id IN ?", roleIds).Find(&roles)
 | 
			
		||||
 | 
			
		||||
		tokenMap := make(map[string]int)
 | 
			
		||||
		userMap := make(map[uint]string)
 | 
			
		||||
		msgMap := make(map[string]int)
 | 
			
		||||
		roleMap := make(map[uint]vo.ChatRole)
 | 
			
		||||
		for _, msg := range messages {
 | 
			
		||||
			tokenMap[msg.ChatId] += msg.Tokens
 | 
			
		||||
			msgMap[msg.ChatId] += 1
 | 
			
		||||
		}
 | 
			
		||||
		for _, user := range users {
 | 
			
		||||
			userMap[user.Id] = user.Username
 | 
			
		||||
		}
 | 
			
		||||
		for _, r := range roles {
 | 
			
		||||
			var roleVo vo.ChatRole
 | 
			
		||||
			err := utils.CopyObject(r, &roleVo)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			roleMap[r.Id] = roleVo
 | 
			
		||||
		}
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			list = append(list, chatItemVo{
 | 
			
		||||
				UserId:    item.UserId,
 | 
			
		||||
				Username:  userMap[item.UserId],
 | 
			
		||||
				ChatId:    item.ChatId,
 | 
			
		||||
				Title:     item.Title,
 | 
			
		||||
				Model:     item.Model,
 | 
			
		||||
				Token:     tokenMap[item.ChatId],
 | 
			
		||||
				MsgNum:    msgMap[item.ChatId],
 | 
			
		||||
				Role:      roleMap[item.RoleId],
 | 
			
		||||
				CreatedAt: item.CreatedAt.Unix(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type chatMessageVo struct {
 | 
			
		||||
	Id        uint   `json:"id"`
 | 
			
		||||
	UserId    uint   `json:"user_id"`
 | 
			
		||||
	Username  string `json:"username"`
 | 
			
		||||
	Content   string `json:"content"`
 | 
			
		||||
	Type      string `json:"type"`
 | 
			
		||||
	Model     string `json:"model"`
 | 
			
		||||
	Token     int    `json:"token"`
 | 
			
		||||
	Icon      string `json:"icon"`
 | 
			
		||||
	CreatedAt int64  `json:"created_at"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Messages 读取聊天记录列表
 | 
			
		||||
func (h *ChatHandler) Messages(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		UserId   uint     `json:"user_id"`
 | 
			
		||||
		Content  string   `json:"content"`
 | 
			
		||||
		Model    string   `json:"model"`
 | 
			
		||||
		CreateAt []string `json:"created_time"`
 | 
			
		||||
		Page     int      `json:"page"`
 | 
			
		||||
		PageSize int      `json:"page_size"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if data.Content != "" {
 | 
			
		||||
		session = session.Where("content LIKE ?", "%"+data.Content+"%")
 | 
			
		||||
	}
 | 
			
		||||
	if data.UserId > 0 {
 | 
			
		||||
		session = session.Where("user_id = ?", data.UserId)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Model != "" {
 | 
			
		||||
		session = session.Where("model = ?", data.Model)
 | 
			
		||||
	}
 | 
			
		||||
	if len(data.CreateAt) == 2 {
 | 
			
		||||
		start := utils.Str2stamp(data.CreateAt[0] + " 00:00:00")
 | 
			
		||||
		end := utils.Str2stamp(data.CreateAt[1] + " 00:00:00")
 | 
			
		||||
		session = session.Where("created_at >= ? AND created_at <= ?", start, end)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.ChatMessage{}).Count(&total)
 | 
			
		||||
	var items []model.ChatMessage
 | 
			
		||||
	var list = make([]chatMessageVo, 0)
 | 
			
		||||
	offset := (data.Page - 1) * data.PageSize
 | 
			
		||||
	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		userIds := make([]uint, 0)
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			userIds = append(userIds, item.UserId)
 | 
			
		||||
		}
 | 
			
		||||
		var users []model.User
 | 
			
		||||
		h.DB.Where("id IN ?", userIds).Find(&users)
 | 
			
		||||
		userMap := make(map[uint]string)
 | 
			
		||||
		for _, user := range users {
 | 
			
		||||
			userMap[user.Id] = user.Username
 | 
			
		||||
		}
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			list = append(list, chatMessageVo{
 | 
			
		||||
				Id:        item.Id,
 | 
			
		||||
				UserId:    item.UserId,
 | 
			
		||||
				Username:  userMap[item.UserId],
 | 
			
		||||
				Content:   item.Content,
 | 
			
		||||
				Model:     item.Model,
 | 
			
		||||
				Token:     item.Tokens,
 | 
			
		||||
				Icon:      item.Icon,
 | 
			
		||||
				Type:      item.Type,
 | 
			
		||||
				CreatedAt: item.CreatedAt.Unix(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// History 获取聊天历史记录
 | 
			
		||||
func (h *ChatHandler) History(c *gin.Context) {
 | 
			
		||||
	chatId := c.Query("chat_id") // 会话 ID
 | 
			
		||||
	var items []model.ChatMessage
 | 
			
		||||
	var messages = make([]vo.HistoryMessage, 0)
 | 
			
		||||
	res := h.DB.Where("chat_id = ?", chatId).Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "No history message")
 | 
			
		||||
		return
 | 
			
		||||
	} else {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var v vo.HistoryMessage
 | 
			
		||||
			err := utils.CopyObject(item, &v)
 | 
			
		||||
			v.CreatedAt = item.CreatedAt.Unix()
 | 
			
		||||
			v.UpdatedAt = item.UpdatedAt.Unix()
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				messages = append(messages, v)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, messages)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveChat 删除对话
 | 
			
		||||
func (h *ChatHandler) RemoveChat(c *gin.Context) {
 | 
			
		||||
	chatId := h.GetTrim(c, "chat_id")
 | 
			
		||||
	if chatId == "" {
 | 
			
		||||
		resp.ERROR(c, "请传入 ChatId")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	// 删除聊天记录
 | 
			
		||||
	res := tx.Unscoped().Debug().Where("chat_id = ?", chatId).Delete(&model.ChatMessage{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "failed to remove chat message")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 删除对话
 | 
			
		||||
	res = tx.Unscoped().Where("chat_id = ?", chatId).Delete(model.ChatItem{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		tx.Rollback() // 回滚
 | 
			
		||||
		resp.ERROR(c, "failed to remove chat")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// RemoveMessage 删除聊天记录
 | 
			
		||||
func (h *ChatHandler) RemoveMessage(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	tx := h.DB.Unscoped().Where("id = ?", id).Delete(&model.ChatMessage{})
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,51 +1,73 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ChatModelHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
 | 
			
		||||
	h := ChatModelHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &ChatModelHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatModelHandler) Save(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id        uint   `json:"id"`
 | 
			
		||||
		Name      string `json:"name"`
 | 
			
		||||
		Value     string `json:"value"`
 | 
			
		||||
		Enabled   bool   `json:"enabled"`
 | 
			
		||||
		SortNum   int    `json:"sort_num"`
 | 
			
		||||
		Platform  string `json:"platform"`
 | 
			
		||||
		Weight    int    `json:"weight"`
 | 
			
		||||
		CreatedAt int64  `json:"created_at"`
 | 
			
		||||
		Id          uint    `json:"id"`
 | 
			
		||||
		Name        string  `json:"name"`
 | 
			
		||||
		Value       string  `json:"value"`
 | 
			
		||||
		Enabled     bool    `json:"enabled"`
 | 
			
		||||
		SortNum     int     `json:"sort_num"`
 | 
			
		||||
		Open        bool    `json:"open"`
 | 
			
		||||
		Platform    string  `json:"platform"`
 | 
			
		||||
		Power       int     `json:"power"`
 | 
			
		||||
		MaxTokens   int     `json:"max_tokens"`  // 最大响应长度
 | 
			
		||||
		MaxContext  int     `json:"max_context"` // 最大上下文长度
 | 
			
		||||
		Temperature float32 `json:"temperature"` // 模型温度
 | 
			
		||||
		KeyId       int     `json:"key_id,omitempty"`
 | 
			
		||||
		CreatedAt   int64   `json:"created_at"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	item := model.ChatModel{Platform: data.Platform, Name: data.Name, Value: data.Value, Enabled: data.Enabled, SortNum: data.SortNum, Weight: data.Weight}
 | 
			
		||||
	item.Id = data.Id
 | 
			
		||||
	if item.Id > 0 {
 | 
			
		||||
		item.CreatedAt = time.Unix(data.CreatedAt, 0)
 | 
			
		||||
	item := model.ChatModel{
 | 
			
		||||
		Platform:    data.Platform,
 | 
			
		||||
		Name:        data.Name,
 | 
			
		||||
		Value:       data.Value,
 | 
			
		||||
		Enabled:     data.Enabled,
 | 
			
		||||
		SortNum:     data.SortNum,
 | 
			
		||||
		Open:        data.Open,
 | 
			
		||||
		MaxTokens:   data.MaxTokens,
 | 
			
		||||
		MaxContext:  data.MaxContext,
 | 
			
		||||
		Temperature: data.Temperature,
 | 
			
		||||
		KeyId:       data.KeyId,
 | 
			
		||||
		Power:       data.Power}
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	if data.Id > 0 {
 | 
			
		||||
		item.Id = data.Id
 | 
			
		||||
		res = h.DB.Select("*").Omit("created_at").Updates(&item)
 | 
			
		||||
	} else {
 | 
			
		||||
		res = h.DB.Create(&item)
 | 
			
		||||
	}
 | 
			
		||||
	res := h.db.Save(&item)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
@@ -64,7 +86,7 @@ func (h *ChatModelHandler) Save(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// List 模型列表
 | 
			
		||||
func (h *ChatModelHandler) List(c *gin.Context) {
 | 
			
		||||
	session := h.db.Session(&gorm.Session{})
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	enable := h.GetBool(c, "enable")
 | 
			
		||||
	if enable {
 | 
			
		||||
		session = session.Where("enabled", enable)
 | 
			
		||||
@@ -72,27 +94,43 @@ func (h *ChatModelHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.ChatModel
 | 
			
		||||
	var cms = make([]vo.ChatModel, 0)
 | 
			
		||||
	res := session.Order("sort_num ASC").Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var cm vo.ChatModel
 | 
			
		||||
			err := utils.CopyObject(item, &cm)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				cm.Id = item.Id
 | 
			
		||||
				cm.CreatedAt = item.CreatedAt.Unix()
 | 
			
		||||
				cm.UpdatedAt = item.UpdatedAt.Unix()
 | 
			
		||||
				cms = append(cms, cm)
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.SUCCESS(c, cms)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// initialize key name
 | 
			
		||||
	keyIds := make([]int, 0)
 | 
			
		||||
	for _, v := range items {
 | 
			
		||||
		keyIds = append(keyIds, v.KeyId)
 | 
			
		||||
	}
 | 
			
		||||
	var keys []model.ApiKey
 | 
			
		||||
	keyMap := make(map[uint]string)
 | 
			
		||||
	h.DB.Where("id IN ?", keyIds).Find(&keys)
 | 
			
		||||
	for _, v := range keys {
 | 
			
		||||
		keyMap[v.Id] = v.Name
 | 
			
		||||
	}
 | 
			
		||||
	for _, item := range items {
 | 
			
		||||
		var cm vo.ChatModel
 | 
			
		||||
		err := utils.CopyObject(item, &cm)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			cm.Id = item.Id
 | 
			
		||||
			cm.CreatedAt = item.CreatedAt.Unix()
 | 
			
		||||
			cm.UpdatedAt = item.UpdatedAt.Unix()
 | 
			
		||||
			cm.KeyName = keyMap[uint(item.KeyId)]
 | 
			
		||||
			cms = append(cms, cm)
 | 
			
		||||
		} else {
 | 
			
		||||
			logger.Error(err)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, cms)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatModelHandler) Enable(c *gin.Context) {
 | 
			
		||||
func (h *ChatModelHandler) Set(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id      uint `json:"id"`
 | 
			
		||||
		Enabled bool `json:"enabled"`
 | 
			
		||||
		Id    uint        `json:"id"`
 | 
			
		||||
		Filed string      `json:"filed"`
 | 
			
		||||
		Value interface{} `json:"value"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
@@ -100,7 +138,7 @@ func (h *ChatModelHandler) Enable(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.db.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update("enabled", data.Enabled)
 | 
			
		||||
	res := h.DB.Model(&model.ChatModel{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
@@ -120,7 +158,7 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for index, id := range data.Ids {
 | 
			
		||||
		res := h.db.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		res := h.DB.Model(&model.ChatModel{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
@@ -132,13 +170,15 @@ func (h *ChatModelHandler) Sort(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
func (h *ChatModelHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	if id <= 0 {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		res := h.db.Where("id = ?", id).Delete(&model.ChatModel{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	res := h.DB.Where("id = ?", id).Delete(&model.ChatModel{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,27 +1,32 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ChatRoleHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
 | 
			
		||||
	h := ChatRoleHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &ChatRoleHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Save 创建或者更新某个角色
 | 
			
		||||
@@ -41,7 +46,7 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
 | 
			
		||||
	if data.CreatedAt > 0 {
 | 
			
		||||
		role.CreatedAt = time.Unix(data.CreatedAt, 0)
 | 
			
		||||
	}
 | 
			
		||||
	res := h.db.Save(&role)
 | 
			
		||||
	res := h.DB.Save(&role)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
@@ -55,12 +60,31 @@ func (h *ChatRoleHandler) Save(c *gin.Context) {
 | 
			
		||||
func (h *ChatRoleHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.ChatRole
 | 
			
		||||
	var roles = make([]vo.ChatRole, 0)
 | 
			
		||||
	res := h.db.Order("sort_num ASC").Find(&items)
 | 
			
		||||
	res := h.DB.Order("sort_num ASC").Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "No data found")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// initialize model mane for role
 | 
			
		||||
	modelIds := make([]int, 0)
 | 
			
		||||
	for _, v := range items {
 | 
			
		||||
		if v.ModelId > 0 {
 | 
			
		||||
			modelIds = append(modelIds, v.ModelId)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	modelNameMap := make(map[int]string)
 | 
			
		||||
	if len(modelIds) > 0 {
 | 
			
		||||
		var models []model.ChatModel
 | 
			
		||||
		tx := h.DB.Where("id IN ?", modelIds).Find(&models)
 | 
			
		||||
		if tx.Error == nil {
 | 
			
		||||
			for _, m := range models {
 | 
			
		||||
				modelNameMap[int(m.Id)] = m.Name
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for _, v := range items {
 | 
			
		||||
		var role vo.ChatRole
 | 
			
		||||
		err := utils.CopyObject(v, &role)
 | 
			
		||||
@@ -68,6 +92,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
 | 
			
		||||
			role.Id = v.Id
 | 
			
		||||
			role.CreatedAt = v.CreatedAt.Unix()
 | 
			
		||||
			role.UpdatedAt = v.UpdatedAt.Unix()
 | 
			
		||||
			role.ModelName = modelNameMap[role.ModelId]
 | 
			
		||||
			roles = append(roles, role)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
@@ -88,7 +113,7 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for index, id := range data.Ids {
 | 
			
		||||
		res := h.db.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		res := h.DB.Model(&model.ChatRole{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
@@ -98,14 +123,34 @@ func (h *ChatRoleHandler) Sort(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatRoleHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	if id <= 0 {
 | 
			
		||||
func (h *ChatRoleHandler) Set(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id    uint        `json:"id"`
 | 
			
		||||
		Filed string      `json:"filed"`
 | 
			
		||||
		Value interface{} `json:"value"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.db.Where("id = ?", id).Delete(&model.ChatRole{})
 | 
			
		||||
	res := h.DB.Model(&model.ChatRole{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatRoleHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
 | 
			
		||||
	if id <= 0 {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	res := h.DB.Where("id", id).Delete(&model.ChatRole{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "删除失败!")
 | 
			
		||||
		return
 | 
			
		||||
 
 | 
			
		||||
@@ -1,49 +1,59 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ConfigHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
	levelDB *store.LevelDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
 | 
			
		||||
	h := ConfigHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB) *ConfigHandler {
 | 
			
		||||
	return &ConfigHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}, levelDB: levelDB}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ConfigHandler) Update(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Key    string                 `json:"key"`
 | 
			
		||||
		Config map[string]interface{} `json:"config"`
 | 
			
		||||
		Key    string `json:"key"`
 | 
			
		||||
		Config struct {
 | 
			
		||||
			types.SystemConfig
 | 
			
		||||
			Content string `json:"content,omitempty"`
 | 
			
		||||
			Updated bool   `json:"updated,omitempty"`
 | 
			
		||||
		} `json:"config"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	str := utils.JsonEncode(&data.Config)
 | 
			
		||||
	config := model.Config{Key: data.Key, Config: str}
 | 
			
		||||
	res := h.db.FirstOrCreate(&config, model.Config{Key: data.Key})
 | 
			
		||||
 | 
			
		||||
	value := utils.JsonEncode(&data.Config)
 | 
			
		||||
	config := model.Config{Key: data.Key, Config: value}
 | 
			
		||||
	res := h.DB.FirstOrCreate(&config, model.Config{Key: data.Key})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.Id > 0 {
 | 
			
		||||
		config.Config = str
 | 
			
		||||
		res := h.db.Updates(&config)
 | 
			
		||||
		config.Config = value
 | 
			
		||||
		res := h.DB.Updates(&config)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, res.Error.Error())
 | 
			
		||||
			return
 | 
			
		||||
@@ -51,12 +61,10 @@ func (h *ConfigHandler) Update(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
		// update config cache for AppServer
 | 
			
		||||
		var cfg model.Config
 | 
			
		||||
		h.db.Where("marker", data.Key).First(&cfg)
 | 
			
		||||
		h.DB.Where("marker", data.Key).First(&cfg)
 | 
			
		||||
		var err error
 | 
			
		||||
		if data.Key == "system" {
 | 
			
		||||
			err = utils.JsonDecode(cfg.Config, &h.App.SysConfig)
 | 
			
		||||
		} else if data.Key == "chat" {
 | 
			
		||||
			err = utils.JsonDecode(cfg.Config, &h.App.ChatConfig)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "Failed to update config cache: "+err.Error())
 | 
			
		||||
@@ -72,18 +80,18 @@ func (h *ConfigHandler) Update(c *gin.Context) {
 | 
			
		||||
func (h *ConfigHandler) Get(c *gin.Context) {
 | 
			
		||||
	key := c.Query("key")
 | 
			
		||||
	var config model.Config
 | 
			
		||||
	res := h.db.Where("marker", key).First(&config)
 | 
			
		||||
	res := h.DB.Where("marker", key).First(&config)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var m map[string]interface{}
 | 
			
		||||
	err := utils.JsonDecode(config.Config, &m)
 | 
			
		||||
	var value map[string]interface{}
 | 
			
		||||
	err := utils.JsonDecode(config.Config, &value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, m)
 | 
			
		||||
	resp.SUCCESS(c, value)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,31 +1,38 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/shopspring/decimal"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DashboardHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDashboardHandler(app *core.AppServer, db *gorm.DB) *DashboardHandler {
 | 
			
		||||
	h := DashboardHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &DashboardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type statsVo struct {
 | 
			
		||||
	Users   int64   `json:"users"`
 | 
			
		||||
	Chats   int64   `json:"chats"`
 | 
			
		||||
	Tokens  int64   `json:"tokens"`
 | 
			
		||||
	Rewards float64 `json:"rewards"`
 | 
			
		||||
	Users  int64                         `json:"users"`
 | 
			
		||||
	Chats  int64                         `json:"chats"`
 | 
			
		||||
	Tokens int                           `json:"tokens"`
 | 
			
		||||
	Income float64                       `json:"income"`
 | 
			
		||||
	Chart  map[string]map[string]float64 `json:"chart"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *DashboardHandler) Stats(c *gin.Context) {
 | 
			
		||||
@@ -34,30 +41,84 @@ func (h *DashboardHandler) Stats(c *gin.Context) {
 | 
			
		||||
	var userCount int64
 | 
			
		||||
	now := time.Now()
 | 
			
		||||
	zeroTime := time.Date(now.Year(), now.Month(), now.Day(), 0, 0, 0, 0, now.Location())
 | 
			
		||||
	res := h.db.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
 | 
			
		||||
	res := h.DB.Model(&model.User{}).Where("created_at > ?", zeroTime).Count(&userCount)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		stats.Users = userCount
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// new chats statistic
 | 
			
		||||
	var chatCount int64
 | 
			
		||||
	res = h.db.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
 | 
			
		||||
	res = h.DB.Model(&model.ChatItem{}).Where("created_at > ?", zeroTime).Count(&chatCount)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		stats.Chats = chatCount
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// tokens took stats
 | 
			
		||||
	var tokenCount int64
 | 
			
		||||
	res = h.db.Model(&model.HistoryMessage{}).Select("sum(tokens) as total").Where("created_at > ?", zeroTime).Scan(&tokenCount)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		stats.Tokens = tokenCount
 | 
			
		||||
	var historyMessages []model.ChatMessage
 | 
			
		||||
	res = h.DB.Where("created_at > ?", zeroTime).Find(&historyMessages)
 | 
			
		||||
	for _, item := range historyMessages {
 | 
			
		||||
		stats.Tokens += item.Tokens
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// reward revenue
 | 
			
		||||
	var amount float64
 | 
			
		||||
	res = h.db.Model(&model.Reward{}).Select("sum(amount) as total").Where("created_at > ?", zeroTime).Scan(&amount)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		stats.Rewards = amount
 | 
			
		||||
	// 众筹收入
 | 
			
		||||
	var rewards []model.Reward
 | 
			
		||||
	res = h.DB.Where("created_at > ?", zeroTime).Find(&rewards)
 | 
			
		||||
	for _, item := range rewards {
 | 
			
		||||
		stats.Income += item.Amount
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 订单收入
 | 
			
		||||
	var orders []model.Order
 | 
			
		||||
	res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", zeroTime).Find(&orders)
 | 
			
		||||
	for _, item := range orders {
 | 
			
		||||
		stats.Income += item.Amount
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 统计7天的订单的图表
 | 
			
		||||
	startDate := now.Add(-7 * 24 * time.Hour).Format("2006-01-02")
 | 
			
		||||
	var statsChart = make(map[string]map[string]float64)
 | 
			
		||||
	//// 初始化
 | 
			
		||||
	var userStatistic, historyMessagesStatistic, incomeStatistic = make(map[string]float64), make(map[string]float64), make(map[string]float64)
 | 
			
		||||
	for i := 0; i < 7; i++ {
 | 
			
		||||
		var initTime = time.Date(now.Year(), now.Month(), now.Day()-i, 0, 0, 0, 0, now.Location()).Format("2006-01-02")
 | 
			
		||||
		userStatistic[initTime] = float64(0)
 | 
			
		||||
		historyMessagesStatistic[initTime] = float64(0)
 | 
			
		||||
		incomeStatistic[initTime] = float64(0)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 统计用户7天增加的曲线
 | 
			
		||||
	var users []model.User
 | 
			
		||||
	res = h.DB.Model(&model.User{}).Where("created_at > ?", startDate).Find(&users)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range users {
 | 
			
		||||
			userStatistic[item.CreatedAt.Format("2006-01-02")] += 1
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 统计7天Token 消耗
 | 
			
		||||
	res = h.DB.Where("created_at > ?", startDate).Find(&historyMessages)
 | 
			
		||||
	for _, item := range historyMessages {
 | 
			
		||||
		historyMessagesStatistic[item.CreatedAt.Format("2006-01-02")] += float64(item.Tokens)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 浮点数相加?
 | 
			
		||||
	// 统计最近7天的众筹
 | 
			
		||||
	res = h.DB.Where("created_at > ?", startDate).Find(&rewards)
 | 
			
		||||
	for _, item := range rewards {
 | 
			
		||||
		incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 统计最近7天的订单
 | 
			
		||||
	res = h.DB.Where("status = ?", types.OrderPaidSuccess).Where("created_at > ?", startDate).Find(&orders)
 | 
			
		||||
	for _, item := range orders {
 | 
			
		||||
		incomeStatistic[item.CreatedAt.Format("2006-01-02")], _ = decimal.NewFromFloat(incomeStatistic[item.CreatedAt.Format("2006-01-02")]).Add(decimal.NewFromFloat(item.Amount)).Float64()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	statsChart["users"] = userStatistic
 | 
			
		||||
	statsChart["historyMessage"] = historyMessagesStatistic
 | 
			
		||||
	statsChart["orders"] = incomeStatistic
 | 
			
		||||
 | 
			
		||||
	stats.Chart = statsChart
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, stats)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										128
									
								
								api/handler/admin/function_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								api/handler/admin/function_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,128 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type FunctionHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewFunctionHandler(app *core.AppServer, db *gorm.DB) *FunctionHandler {
 | 
			
		||||
	return &FunctionHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *FunctionHandler) Save(c *gin.Context) {
 | 
			
		||||
	var data vo.Function
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var f = model.Function{
 | 
			
		||||
		Id:          data.Id,
 | 
			
		||||
		Name:        data.Name,
 | 
			
		||||
		Label:       data.Label,
 | 
			
		||||
		Description: data.Description,
 | 
			
		||||
		Parameters:  utils.JsonEncode(data.Parameters),
 | 
			
		||||
		Action:      data.Action,
 | 
			
		||||
		Token:       data.Token,
 | 
			
		||||
		Enabled:     data.Enabled,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Save(&f)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "error with save data:"+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	data.Id = f.Id
 | 
			
		||||
	resp.SUCCESS(c, data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *FunctionHandler) Set(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id    uint        `json:"id"`
 | 
			
		||||
		Filed string      `json:"filed"`
 | 
			
		||||
		Value interface{} `json:"value"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.Function{}).Where("id = ?", data.Id).Update(data.Filed, data.Value)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *FunctionHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.Function
 | 
			
		||||
	res := h.DB.Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "No data found")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	functions := make([]vo.Function, 0)
 | 
			
		||||
	for _, v := range items {
 | 
			
		||||
		var f vo.Function
 | 
			
		||||
		err := utils.CopyObject(v, &f)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		functions = append(functions, f)
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, functions)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *FunctionHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		res := h.DB.Delete(&model.Function{Id: uint(id)})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GenToken generate function api access token
 | 
			
		||||
func (h *FunctionHandler) GenToken(c *gin.Context) {
 | 
			
		||||
	// 创建 token
 | 
			
		||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
		"user_id": 0,
 | 
			
		||||
		"expired": 0,
 | 
			
		||||
	})
 | 
			
		||||
	tokenString, err := token.SignedString([]byte(h.App.Config.Session.SecretKey))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("error with generate token", err)
 | 
			
		||||
		resp.ERROR(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, tokenString)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										128
									
								
								api/handler/admin/menu_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										128
									
								
								api/handler/admin/menu_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,128 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MenuHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
 | 
			
		||||
	return &MenuHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MenuHandler) Save(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id      uint   `json:"id"`
 | 
			
		||||
		Name    string `json:"name"`
 | 
			
		||||
		Icon    string `json:"icon"`
 | 
			
		||||
		URL     string `json:"url"`
 | 
			
		||||
		SortNum int    `json:"sort_num"`
 | 
			
		||||
		Enabled bool   `json:"enabled"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Save(&model.Menu{
 | 
			
		||||
		Id:      data.Id,
 | 
			
		||||
		Name:    data.Name,
 | 
			
		||||
		Icon:    data.Icon,
 | 
			
		||||
		URL:     data.URL,
 | 
			
		||||
		SortNum: data.SortNum,
 | 
			
		||||
		Enabled: data.Enabled,
 | 
			
		||||
	})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List 数据列表
 | 
			
		||||
func (h *MenuHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.Menu
 | 
			
		||||
	var list = make([]vo.Menu, 0)
 | 
			
		||||
	res := h.DB.Order("sort_num ASC").Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var product vo.Menu
 | 
			
		||||
			err := utils.CopyObject(item, &product)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				list = append(list, product)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, list)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MenuHandler) Enable(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id      uint `json:"id"`
 | 
			
		||||
		Enabled bool `json:"enabled"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.Menu{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MenuHandler) Sort(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Ids   []uint `json:"ids"`
 | 
			
		||||
		Sorts []int  `json:"sorts"`
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for index, id := range data.Ids {
 | 
			
		||||
		res := h.DB.Model(&model.Menu{}).Where("id", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MenuHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		res := h.DB.Where("id", id).Delete(&model.Menu{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,31 +1,37 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type OrderHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
 | 
			
		||||
	h := OrderHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &OrderHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *OrderHandler) List(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		OrderNo  string   `json:"order_no"`
 | 
			
		||||
		Status   int      `json:"status"`
 | 
			
		||||
		PayTime  []string `json:"pay_time"`
 | 
			
		||||
		Page     int      `json:"page"`
 | 
			
		||||
		PageSize int      `json:"page_size"`
 | 
			
		||||
@@ -35,7 +41,7 @@ func (h *OrderHandler) List(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := h.db.Session(&gorm.Session{})
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if data.OrderNo != "" {
 | 
			
		||||
		session = session.Where("order_no", data.OrderNo)
 | 
			
		||||
	}
 | 
			
		||||
@@ -44,6 +50,9 @@ func (h *OrderHandler) List(c *gin.Context) {
 | 
			
		||||
		end := utils.Str2stamp(data.PayTime[1] + " 00:00:00")
 | 
			
		||||
		session = session.Where("pay_time >= ? AND pay_time <= ?", start, end)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Status >= 0 {
 | 
			
		||||
		session = session.Where("status", data.Status)
 | 
			
		||||
	}
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.Order{}).Count(&total)
 | 
			
		||||
	var items []model.Order
 | 
			
		||||
@@ -72,7 +81,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		var item model.Order
 | 
			
		||||
		res := h.db.First(&item, id)
 | 
			
		||||
		res := h.DB.First(&item, id)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "记录不存在!")
 | 
			
		||||
			return
 | 
			
		||||
@@ -83,7 +92,7 @@ func (h *OrderHandler) Remove(c *gin.Context) {
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		res = h.db.Where("id = ?", id).Delete(&model.Order{})
 | 
			
		||||
		res = h.DB.Unscoped().Where("id = ?", id).Delete(&model.Order{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										84
									
								
								api/handler/admin/power_log_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										84
									
								
								api/handler/admin/power_log_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,84 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PowerLogHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
 | 
			
		||||
	return &PowerLogHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *PowerLogHandler) List(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Username string   `json:"username"`
 | 
			
		||||
		Type     int      `json:"type"`
 | 
			
		||||
		Model    string   `json:"model"`
 | 
			
		||||
		Date     []string `json:"date"`
 | 
			
		||||
		Page     int      `json:"page"`
 | 
			
		||||
		PageSize int      `json:"page_size"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if data.Model != "" {
 | 
			
		||||
		session = session.Where("model", data.Model)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Type > 0 {
 | 
			
		||||
		session = session.Where("type", data.Type)
 | 
			
		||||
	}
 | 
			
		||||
	if len(data.Date) == 2 {
 | 
			
		||||
		start := data.Date[0] + " 00:00:00"
 | 
			
		||||
		end := data.Date[1] + " 00:00:00"
 | 
			
		||||
		session = session.Where("created_at >= ? AND created_at <= ?", start, end)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.PowerLog{}).Count(&total)
 | 
			
		||||
	var items []model.PowerLog
 | 
			
		||||
	var list = make([]vo.PowerLog, 0)
 | 
			
		||||
	offset := (data.Page - 1) * data.PageSize
 | 
			
		||||
	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var log vo.PowerLog
 | 
			
		||||
			err := utils.CopyObject(item, &log)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			log.Id = item.Id
 | 
			
		||||
			log.CreatedAt = item.CreatedAt.Unix()
 | 
			
		||||
			log.TypeStr = item.Type.String()
 | 
			
		||||
			list = append(list, log)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 统计消费算力总和
 | 
			
		||||
	var totalPower float64
 | 
			
		||||
	if len(data.Date) == 2 {
 | 
			
		||||
		session.Where("mark", 0).Select("SUM(amount) as total_sum").Scan(&totalPower)
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, gin.H{"data": vo.NewPage(total, data.Page, data.PageSize, list), "stat": totalPower})
 | 
			
		||||
}
 | 
			
		||||
@@ -1,13 +1,20 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -15,13 +22,10 @@ import (
 | 
			
		||||
 | 
			
		||||
type ProductHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
 | 
			
		||||
	h := ProductHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &ProductHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ProductHandler) Save(c *gin.Context) {
 | 
			
		||||
@@ -32,7 +36,7 @@ func (h *ProductHandler) Save(c *gin.Context) {
 | 
			
		||||
		Discount  float64 `json:"discount"`
 | 
			
		||||
		Enabled   bool    `json:"enabled"`
 | 
			
		||||
		Days      int     `json:"days"`
 | 
			
		||||
		Calls     int     `json:"calls"`
 | 
			
		||||
		Power     int     `json:"power"`
 | 
			
		||||
		CreatedAt int64   `json:"created_at"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
@@ -40,12 +44,18 @@ func (h *ProductHandler) Save(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	item := model.Product{Name: data.Name, Price: data.Price, Discount: data.Discount, Days: data.Days, Calls: data.Calls, Enabled: data.Enabled}
 | 
			
		||||
	item := model.Product{
 | 
			
		||||
		Name:     data.Name,
 | 
			
		||||
		Price:    data.Price,
 | 
			
		||||
		Discount: data.Discount,
 | 
			
		||||
		Days:     data.Days,
 | 
			
		||||
		Power:    data.Power,
 | 
			
		||||
		Enabled:  data.Enabled}
 | 
			
		||||
	item.Id = data.Id
 | 
			
		||||
	if item.Id > 0 {
 | 
			
		||||
		item.CreatedAt = time.Unix(data.CreatedAt, 0)
 | 
			
		||||
	}
 | 
			
		||||
	res := h.db.Save(&item)
 | 
			
		||||
	res := h.DB.Save(&item)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
@@ -62,16 +72,11 @@ func (h *ProductHandler) Save(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, itemVo)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List 模型列表
 | 
			
		||||
// List 数据列表
 | 
			
		||||
func (h *ProductHandler) List(c *gin.Context) {
 | 
			
		||||
	session := h.db.Session(&gorm.Session{})
 | 
			
		||||
	enable := h.GetBool(c, "enable")
 | 
			
		||||
	if enable {
 | 
			
		||||
		session = session.Where("enabled", enable)
 | 
			
		||||
	}
 | 
			
		||||
	var items []model.Product
 | 
			
		||||
	var list = make([]vo.Product, 0)
 | 
			
		||||
	res := session.Order("sort_num ASC").Find(&items)
 | 
			
		||||
	res := h.DB.Order("sort_num ASC").Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var product vo.Product
 | 
			
		||||
@@ -100,7 +105,7 @@ func (h *ProductHandler) Enable(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.db.Model(&model.Product{}).Where("id = ?", data.Id).Update("enabled", data.Enabled)
 | 
			
		||||
	res := h.DB.Model(&model.Product{}).Where("id", data.Id).UpdateColumn("enabled", data.Enabled)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
@@ -120,7 +125,7 @@ func (h *ProductHandler) Sort(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	for index, id := range data.Ids {
 | 
			
		||||
		res := h.db.Model(&model.Product{}).Where("id = ?", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		res := h.DB.Model(&model.Product{}).Where("id", id).Update("sort_num", data.Sorts[index])
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
@@ -134,7 +139,7 @@ func (h *ProductHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		res := h.db.Where("id = ?", id).Delete(&model.Product{})
 | 
			
		||||
		res := h.DB.Where("id", id).Delete(&model.Product{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
 
 | 
			
		||||
@@ -1,30 +1,35 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RewardHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
 | 
			
		||||
	h := RewardHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &RewardHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RewardHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.Reward
 | 
			
		||||
	res := h.db.Order("id DESC").Find(&items)
 | 
			
		||||
	res := h.DB.Order("id DESC").Find(&items)
 | 
			
		||||
	var rewards = make([]vo.Reward, 0)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		userIds := make([]uint, 0)
 | 
			
		||||
@@ -32,7 +37,7 @@ func (h *RewardHandler) List(c *gin.Context) {
 | 
			
		||||
			userIds = append(userIds, v.UserId)
 | 
			
		||||
		}
 | 
			
		||||
		var users []model.User
 | 
			
		||||
		h.db.Where("id IN ?", userIds).Find(&users)
 | 
			
		||||
		h.DB.Where("id IN ?", userIds).Find(&users)
 | 
			
		||||
		var userMap = make(map[uint]model.User)
 | 
			
		||||
		for _, u := range users {
 | 
			
		||||
			userMap[u.Id] = u
 | 
			
		||||
@@ -46,7 +51,7 @@ func (h *RewardHandler) List(c *gin.Context) {
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			r.Id = v.Id
 | 
			
		||||
			r.Username = userMap[v.UserId].Mobile
 | 
			
		||||
			r.Username = userMap[v.UserId].Username
 | 
			
		||||
			r.CreatedAt = v.CreatedAt.Unix()
 | 
			
		||||
			r.UpdatedAt = v.UpdatedAt.Unix()
 | 
			
		||||
			rewards = append(rewards, r)
 | 
			
		||||
@@ -55,3 +60,21 @@ func (h *RewardHandler) List(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, rewards)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *RewardHandler) Remove(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id uint
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if data.Id > 0 {
 | 
			
		||||
		res := h.DB.Where("id = ?", data.Id).Delete(&model.Reward{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										52
									
								
								api/handler/admin/upload_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										52
									
								
								api/handler/admin/upload_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,52 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type UploadHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	uploaderManager *oss.UploaderManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
 | 
			
		||||
	return &UploadHandler{BaseHandler: handler.BaseHandler{DB: db, App: app}, uploaderManager: manager}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *UploadHandler) Upload(c *gin.Context) {
 | 
			
		||||
	file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	userId := 0
 | 
			
		||||
	res := h.DB.Create(&model.File{
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
		Name:      file.Name,
 | 
			
		||||
		ObjKey:    file.ObjKey,
 | 
			
		||||
		URL:       file.URL,
 | 
			
		||||
		Ext:       file.Ext,
 | 
			
		||||
		Size:      file.Size,
 | 
			
		||||
		CreatedAt: time.Time{},
 | 
			
		||||
	})
 | 
			
		||||
	if res.Error != nil || res.RowsAffected == 0 {
 | 
			
		||||
		resp.ERROR(c, "error with update database: "+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, file)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,42 +1,49 @@
 | 
			
		||||
package admin
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type UserHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUserHandler(app *core.AppServer, db *gorm.DB) *UserHandler {
 | 
			
		||||
	h := UserHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &UserHandler{BaseHandler: handler.BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List 用户列表
 | 
			
		||||
func (h *UserHandler) List(c *gin.Context) {
 | 
			
		||||
	page := h.GetInt(c, "page", 1)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 20)
 | 
			
		||||
	mobile := h.GetTrim(c, "mobile")
 | 
			
		||||
	username := h.GetTrim(c, "username")
 | 
			
		||||
 | 
			
		||||
	offset := (page - 1) * pageSize
 | 
			
		||||
	var items []model.User
 | 
			
		||||
	var users = make([]vo.User, 0)
 | 
			
		||||
	var total int64
 | 
			
		||||
 | 
			
		||||
	session := h.db.Session(&gorm.Session{})
 | 
			
		||||
	if mobile != "" {
 | 
			
		||||
		session = session.Where("mobile LIKE ?", "%"+mobile+"%")
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if username != "" {
 | 
			
		||||
		session = session.Where("username LIKE ?", "%"+username+"%")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session.Model(&model.User{}).Count(&total)
 | 
			
		||||
@@ -63,55 +70,77 @@ func (h *UserHandler) Save(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id          uint     `json:"id"`
 | 
			
		||||
		Password    string   `json:"password"`
 | 
			
		||||
		Mobile      string   `json:"mobile"`
 | 
			
		||||
		Calls       int      `json:"calls"`
 | 
			
		||||
		ImgCalls    int      `json:"img_calls"`
 | 
			
		||||
		Username    string   `json:"username"`
 | 
			
		||||
		ChatRoles   []string `json:"chat_roles"`
 | 
			
		||||
		ChatModels  []string `json:"chat_models"`
 | 
			
		||||
		ChatModels  []int    `json:"chat_models"`
 | 
			
		||||
		ExpiredTime string   `json:"expired_time"`
 | 
			
		||||
		Status      bool     `json:"status"`
 | 
			
		||||
		Vip         bool     `json:"vip"`
 | 
			
		||||
		Power       int      `json:"power"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user = model.User{}
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	var userVo vo.User
 | 
			
		||||
	if data.Id > 0 { // 更新
 | 
			
		||||
		user.Id = data.Id
 | 
			
		||||
		// 此处需要用 map 更新,用结构体无法更新 0 值
 | 
			
		||||
		res = h.db.Model(&user).Updates(map[string]interface{}{
 | 
			
		||||
			"mobile":           data.Mobile,
 | 
			
		||||
			"calls":            data.Calls,
 | 
			
		||||
			"img_calls":        data.ImgCalls,
 | 
			
		||||
			"status":           data.Status,
 | 
			
		||||
			"chat_roles_json":  utils.JsonEncode(data.ChatRoles),
 | 
			
		||||
			"chat_models_json": utils.JsonEncode(data.ChatModels),
 | 
			
		||||
			"expired_time":     utils.Str2stamp(data.ExpiredTime),
 | 
			
		||||
		})
 | 
			
		||||
		res = h.DB.Where("id", data.Id).First(&user)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "user not found")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		var oldPower = user.Power
 | 
			
		||||
		user.Username = data.Username
 | 
			
		||||
		user.Status = data.Status
 | 
			
		||||
		user.Vip = data.Vip
 | 
			
		||||
		user.Power = data.Power
 | 
			
		||||
		user.ChatRoles = utils.JsonEncode(data.ChatRoles)
 | 
			
		||||
		user.ChatModels = utils.JsonEncode(data.ChatModels)
 | 
			
		||||
		user.ExpiredTime = utils.Str2stamp(data.ExpiredTime)
 | 
			
		||||
 | 
			
		||||
		res = h.DB.Select("username", "status", "vip", "power", "chat_roles_json", "chat_models_json", "expired_time").Updates(&user)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// 记录算力日志
 | 
			
		||||
		if oldPower != user.Power {
 | 
			
		||||
			mark := types.PowerAdd
 | 
			
		||||
			amount := user.Power - oldPower
 | 
			
		||||
			if oldPower > user.Power {
 | 
			
		||||
				mark = types.PowerSub
 | 
			
		||||
				amount = oldPower - user.Power
 | 
			
		||||
			}
 | 
			
		||||
			h.DB.Create(&model.PowerLog{
 | 
			
		||||
				UserId:    user.Id,
 | 
			
		||||
				Username:  user.Username,
 | 
			
		||||
				Type:      types.PowerGift,
 | 
			
		||||
				Amount:    amount,
 | 
			
		||||
				Balance:   user.Power,
 | 
			
		||||
				Mark:      mark,
 | 
			
		||||
				Model:     "管理员",
 | 
			
		||||
				Remark:    fmt.Sprintf("后台管理员强制修改用户算力,修改前:%d,修改后:%d, 管理员ID:%d", oldPower, user.Power, h.GetLoginUserId(c)),
 | 
			
		||||
				CreatedAt: time.Now(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		salt := utils.RandString(8)
 | 
			
		||||
		u := model.User{
 | 
			
		||||
			Mobile:      data.Mobile,
 | 
			
		||||
			Username:    data.Username,
 | 
			
		||||
			Nickname:    fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
 | 
			
		||||
			Password:    utils.GenPassword(data.Password, salt),
 | 
			
		||||
			Avatar:      "/images/avatar/user.png",
 | 
			
		||||
			Salt:        salt,
 | 
			
		||||
			Power:       data.Power,
 | 
			
		||||
			Status:      true,
 | 
			
		||||
			ChatRoles:   utils.JsonEncode(data.ChatRoles),
 | 
			
		||||
			ChatModels:  utils.JsonEncode(data.ChatModels),
 | 
			
		||||
			ExpiredTime: utils.Str2stamp(data.ExpiredTime),
 | 
			
		||||
			ChatConfig: utils.JsonEncode(types.UserChatConfig{
 | 
			
		||||
				ApiKeys: map[types.Platform]string{
 | 
			
		||||
					types.OpenAI:  "",
 | 
			
		||||
					types.Azure:   "",
 | 
			
		||||
					types.ChatGLM: "",
 | 
			
		||||
				},
 | 
			
		||||
			}),
 | 
			
		||||
			Calls:    data.Calls,
 | 
			
		||||
			ImgCalls: data.ImgCalls,
 | 
			
		||||
		}
 | 
			
		||||
		res = h.db.Create(&u)
 | 
			
		||||
		res = h.DB.Create(&u)
 | 
			
		||||
		_ = utils.CopyObject(u, &userVo)
 | 
			
		||||
		userVo.Id = u.Id
 | 
			
		||||
		userVo.CreatedAt = u.CreatedAt.Unix()
 | 
			
		||||
@@ -138,7 +167,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.db.First(&user, data.Id)
 | 
			
		||||
	res := h.DB.First(&user, data.Id)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "No user found")
 | 
			
		||||
		return
 | 
			
		||||
@@ -146,7 +175,7 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	password := utils.GenPassword(data.Password, user.Salt)
 | 
			
		||||
	user.Password = password
 | 
			
		||||
	res = h.db.Updates(&user)
 | 
			
		||||
	res = h.DB.Updates(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c)
 | 
			
		||||
	} else {
 | 
			
		||||
@@ -156,36 +185,32 @@ func (h *UserHandler) ResetPass(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
func (h *UserHandler) Remove(c *gin.Context) {
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	if id > 0 {
 | 
			
		||||
		tx := h.db.Begin()
 | 
			
		||||
		res := h.db.Where("id = ?", id).Delete(&model.User{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "删除失败")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// 删除聊天记录
 | 
			
		||||
		res = h.db.Where("user_id = ?", id).Delete(&model.ChatItem{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, "删除失败")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// 删除聊天历史记录
 | 
			
		||||
		res = h.db.Where("user_id = ?", id).Delete(&model.HistoryMessage{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, "删除失败")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// 删除登录日志
 | 
			
		||||
		res = h.db.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			tx.Rollback()
 | 
			
		||||
			resp.ERROR(c, "删除失败")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		tx.Commit()
 | 
			
		||||
	if id <= 0 {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 删除用户
 | 
			
		||||
	res := h.DB.Where("id = ?", id).Delete(&model.User{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "删除失败")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 删除聊天记录
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.ChatItem{})
 | 
			
		||||
	// 删除聊天历史记录
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.ChatMessage{})
 | 
			
		||||
	// 删除登录日志
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.UserLoginLog{})
 | 
			
		||||
	// 删除算力日志
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.PowerLog{})
 | 
			
		||||
	// 删除众筹日志
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.Reward{})
 | 
			
		||||
	// 删除绘图任务
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.MidJourneyJob{})
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.SdJob{})
 | 
			
		||||
	//  删除订单
 | 
			
		||||
	h.DB.Where("user_id = ?", id).Delete(&model.Order{})
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -193,10 +218,10 @@ func (h *UserHandler) LoginLog(c *gin.Context) {
 | 
			
		||||
	page := h.GetInt(c, "page", 1)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 20)
 | 
			
		||||
	var total int64
 | 
			
		||||
	h.db.Model(&model.UserLoginLog{}).Count(&total)
 | 
			
		||||
	h.DB.Model(&model.UserLoginLog{}).Count(&total)
 | 
			
		||||
	offset := (page - 1) * pageSize
 | 
			
		||||
	var items []model.UserLoginLog
 | 
			
		||||
	res := h.db.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
 | 
			
		||||
	res := h.DB.Offset(offset).Limit(pageSize).Order("id DESC").Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "获取数据失败")
 | 
			
		||||
		return
 | 
			
		||||
 
 | 
			
		||||
@@ -1,11 +1,21 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
@@ -15,6 +25,7 @@ var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type BaseHandler struct {
 | 
			
		||||
	App *core.AppServer
 | 
			
		||||
	DB  *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *BaseHandler) GetTrim(c *gin.Context, key string) string {
 | 
			
		||||
@@ -49,3 +60,35 @@ func (h *BaseHandler) GetUserKey(c *gin.Context) string {
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("users/%v", userId)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *BaseHandler) GetLoginUserId(c *gin.Context) uint {
 | 
			
		||||
	userId, ok := c.Get(types.LoginUserID)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return 0
 | 
			
		||||
	}
 | 
			
		||||
	return uint(utils.IntValue(utils.InterfaceToString(userId), 0))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *BaseHandler) IsLogin(c *gin.Context) bool {
 | 
			
		||||
	return h.GetLoginUserId(c) > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *BaseHandler) GetLoginUser(c *gin.Context) (model.User, error) {
 | 
			
		||||
	value, exists := c.Get(types.LoginUserCache)
 | 
			
		||||
	if exists {
 | 
			
		||||
		return value.(model.User), nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId, ok := c.Get(types.LoginUserID)
 | 
			
		||||
	if !ok {
 | 
			
		||||
		return model.User{}, errors.New("user not login")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.DB.First(&user, userId)
 | 
			
		||||
	// 更新缓存
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		c.Set(types.LoginUserCache, user)
 | 
			
		||||
	}
 | 
			
		||||
	return user, res.Error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,16 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -45,3 +52,33 @@ func (h *CaptchaHandler) Check(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SlideGet 获取滑动验证图片
 | 
			
		||||
func (h *CaptchaHandler) SlideGet(c *gin.Context) {
 | 
			
		||||
	data, err := h.service.SlideGet()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SlideCheck 滑动验证结果校验
 | 
			
		||||
func (h *CaptchaHandler) SlideCheck(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Key string `json:"key"`
 | 
			
		||||
		X   int    `json:"x"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if h.service.SlideCheck(data) {
 | 
			
		||||
		resp.SUCCESS(c)
 | 
			
		||||
	} else {
 | 
			
		||||
		resp.ERROR(c)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,45 +1,53 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ChatModelHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChatModelHandler(app *core.AppServer, db *gorm.DB) *ChatModelHandler {
 | 
			
		||||
	h := ChatModelHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &ChatModelHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List 模型列表
 | 
			
		||||
func (h *ChatModelHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.ChatModel
 | 
			
		||||
	var chatModels = make([]vo.ChatModel, 0)
 | 
			
		||||
	// 只加载用户订阅的 AI 模型
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	// 如果用户没有登录,则加载所有开放模型
 | 
			
		||||
	if !h.IsLogin(c) {
 | 
			
		||||
		res = h.DB.Where("enabled", true).Where("open", true).Order("sort_num ASC").Find(&items)
 | 
			
		||||
	} else {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		var models []int
 | 
			
		||||
		err := utils.JsonDecode(user.ChatModels, &models)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "当前用户没有订阅任何模型")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		// 查询用户有权限访问的模型以及所有开放的模型
 | 
			
		||||
		res = h.DB.Where("enabled = ?", true).Where(
 | 
			
		||||
			h.DB.Where("id IN ?", models).Or("open", true),
 | 
			
		||||
		).Order("sort_num ASC").Find(&items)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var models []string
 | 
			
		||||
	err = utils.JsonDecode(user.ChatModels, &models)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "当前用户没有订阅任何模型")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.db.Where("enabled = ?", true).Where("value IN ?", models).Order("sort_num ASC").Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var cm vo.ChatModel
 | 
			
		||||
 
 | 
			
		||||
@@ -1,12 +1,19 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
@@ -14,27 +21,26 @@ import (
 | 
			
		||||
 | 
			
		||||
type ChatRoleHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChatRoleHandler(app *core.AppServer, db *gorm.DB) *ChatRoleHandler {
 | 
			
		||||
	handler := &ChatRoleHandler{db: db}
 | 
			
		||||
	handler.App = app
 | 
			
		||||
	return handler
 | 
			
		||||
	return &ChatRoleHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List get user list
 | 
			
		||||
// List 获取用户聊天应用列表
 | 
			
		||||
func (h *ChatRoleHandler) List(c *gin.Context) {
 | 
			
		||||
	all := h.GetBool(c, "all")
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var roles []model.ChatRole
 | 
			
		||||
	res := h.db.Where("enable", true).Order("sort_num ASC").Find(&roles)
 | 
			
		||||
	var roleVos = make([]vo.ChatRole, 0)
 | 
			
		||||
	res := h.DB.Where("enable", true).Order("sort_num ASC").Find(&roles)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "No roles found,"+res.Error.Error())
 | 
			
		||||
		resp.SUCCESS(c, roleVos)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 获取所有角色
 | 
			
		||||
	if all {
 | 
			
		||||
	if userId == 0 || all {
 | 
			
		||||
		// 转成 vo
 | 
			
		||||
		var roleVos = make([]vo.ChatRole, 0)
 | 
			
		||||
		for _, r := range roles {
 | 
			
		||||
@@ -49,21 +55,15 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	h.db.First(&user, userId)
 | 
			
		||||
	h.DB.First(&user, userId)
 | 
			
		||||
	var roleKeys []string
 | 
			
		||||
	err := utils.JsonDecode(user.ChatRoles, &roleKeys)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "角色解析失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 转成 vo
 | 
			
		||||
	var roleVos = make([]vo.ChatRole, 0)
 | 
			
		||||
 | 
			
		||||
	for _, r := range roles {
 | 
			
		||||
		if !utils.ContainsStr(roleKeys, r.Key) {
 | 
			
		||||
			continue
 | 
			
		||||
@@ -80,7 +80,7 @@ func (h *ChatRoleHandler) List(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// UpdateRole 更新用户聊天角色
 | 
			
		||||
func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
@@ -94,7 +94,7 @@ func (h *ChatRoleHandler) UpdateRole(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.db.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
 | 
			
		||||
	res := h.DB.Model(&model.User{}).Where("id = ?", user.Id).UpdateColumn("chat_roles_json", utils.JsonEncode(data.Keys))
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("添加应用失败:", err)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
 
 | 
			
		||||
@@ -1,15 +1,23 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -19,7 +27,7 @@ import (
 | 
			
		||||
// 微软 Azure 模型消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
@@ -29,22 +37,17 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			logger.Error(err)
 | 
			
		||||
			return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
@@ -56,9 +59,6 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		var functionCall = false
 | 
			
		||||
		var functionName string
 | 
			
		||||
		var arguments = make([]string, 0)
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
@@ -68,34 +68,14 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
 | 
			
		||||
			var responseBody = types.ApiResponse{}
 | 
			
		||||
			err = json.Unmarshal([]byte(line[6:]), &responseBody)
 | 
			
		||||
			if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
 | 
			
		||||
				logger.Error(err, line)
 | 
			
		||||
				utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
				utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
				break
 | 
			
		||||
			if err != nil { // 数据解析出错
 | 
			
		||||
				return errors.New(line)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			fun := responseBody.Choices[0].Delta.FunctionCall
 | 
			
		||||
			if functionCall && fun.Name == "" {
 | 
			
		||||
				arguments = append(arguments, fun.Arguments)
 | 
			
		||||
			if len(responseBody.Choices) == 0 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !utils.IsEmptyValue(fun) {
 | 
			
		||||
				functionName = fun.Name
 | 
			
		||||
				f := h.App.Functions[functionName]
 | 
			
		||||
				if f != nil {
 | 
			
		||||
					functionCall = true
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 初始化 role
 | 
			
		||||
			if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
 | 
			
		||||
				message.Role = responseBody.Choices[0].Delta.Role
 | 
			
		||||
@@ -121,54 +101,8 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if functionCall { // 调用函数完成任务
 | 
			
		||||
			var params map[string]interface{}
 | 
			
		||||
			_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
 | 
			
		||||
			logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
 | 
			
		||||
 | 
			
		||||
			// for creating image, check if the user's img_calls > 0
 | 
			
		||||
			if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
 | 
			
		||||
				utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
 | 
			
		||||
				utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
			} else {
 | 
			
		||||
				f := h.App.Functions[functionName]
 | 
			
		||||
				if functionName == types.FuncMidJourney {
 | 
			
		||||
					params["user_id"] = userVo.Id
 | 
			
		||||
					params["role_id"] = role.Id
 | 
			
		||||
					params["chat_id"] = session.ChatId
 | 
			
		||||
					params["icon"] = "/images/avatar/mid_journey.png"
 | 
			
		||||
					params["session_id"] = session.SessionId
 | 
			
		||||
				}
 | 
			
		||||
				data, err := f.Invoke(params)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					msg := "调用函数出错:" + err.Error()
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: msg,
 | 
			
		||||
					})
 | 
			
		||||
					contents = append(contents, msg)
 | 
			
		||||
				} else {
 | 
			
		||||
					content := data
 | 
			
		||||
					if functionName == types.FuncMidJourney {
 | 
			
		||||
						content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
			
		||||
						h.mjService.ChatClients.Put(session.SessionId, ws)
 | 
			
		||||
						// update user's img_calls
 | 
			
		||||
						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: content,
 | 
			
		||||
					})
 | 
			
		||||
					contents = append(contents, content)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			// 更新用户的对话次数
 | 
			
		||||
			h.subUserCalls(userVo, session)
 | 
			
		||||
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
@@ -177,77 +111,64 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.ChatConfig.EnableContext && functionCall == false {
 | 
			
		||||
			if h.App.SysConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
				useContext := true
 | 
			
		||||
				if functionCall {
 | 
			
		||||
					useContext = false
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// for prompt
 | 
			
		||||
				promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.PromptMsg,
 | 
			
		||||
					Icon:       userVo.Avatar,
 | 
			
		||||
					Content:    prompt,
 | 
			
		||||
					Tokens:     promptToken,
 | 
			
		||||
					UseContext: useContext,
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
				historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
				res := h.db.Save(&historyUserMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 计算本次对话消耗的总 token 数量
 | 
			
		||||
				var totalTokens = 0
 | 
			
		||||
				if functionCall { // prompt + 函数名 + 参数 token
 | 
			
		||||
					tokens, _ := utils.CalcTokens(functionName, req.Model)
 | 
			
		||||
					totalTokens += tokens
 | 
			
		||||
					tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
 | 
			
		||||
					totalTokens += tokens
 | 
			
		||||
				} else {
 | 
			
		||||
					totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
				}
 | 
			
		||||
				totalTokens += getTotalTokens(req)
 | 
			
		||||
 | 
			
		||||
				historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.ReplyMsg,
 | 
			
		||||
					Icon:       role.Icon,
 | 
			
		||||
					Content:    message.Content,
 | 
			
		||||
					Tokens:     totalTokens,
 | 
			
		||||
					UseContext: useContext,
 | 
			
		||||
				}
 | 
			
		||||
				historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
				historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
				res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 更新用户信息
 | 
			
		||||
				h.incUserTokenFee(userVo.Id, totalTokens)
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			replyTokens += getTotalTokens(req)
 | 
			
		||||
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     replyTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
@@ -258,7 +179,8 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				h.db.Create(&chatItem)
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
@@ -274,11 +196,10 @@ func (h *ChatHandler) sendAzureMessage(
 | 
			
		||||
 | 
			
		||||
		if strings.Contains(res.Error.Message, "maximum context length") {
 | 
			
		||||
			logger.Error(res.Error.Message)
 | 
			
		||||
			utils.ReplyMessage(ws, "当前会话上下文长度超出限制,已为您清空会话上下文!")
 | 
			
		||||
			h.App.ChatContexts.Delete(session.ChatId)
 | 
			
		||||
			return h.sendMessage(ctx, session, role, prompt, ws)
 | 
			
		||||
		} else {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 Azure API 失败:"+res.Error.Message)
 | 
			
		||||
			return fmt.Errorf("请求 Azure API 失败:%v", res.Error)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,14 +1,23 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
@@ -35,7 +44,7 @@ type baiduResp struct {
 | 
			
		||||
// 百度文心一言消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
@@ -45,22 +54,16 @@ func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			logger.Error(err)
 | 
			
		||||
			return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
@@ -84,6 +87,11 @@ func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
				content = line[5:]
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 处理代码换行
 | 
			
		||||
			if len(content) == 0 {
 | 
			
		||||
				content = "\n"
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var resp baiduResp
 | 
			
		||||
			err := utils.JsonDecode(content, &resp)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
@@ -122,9 +130,6 @@ func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			// 更新用户的对话次数
 | 
			
		||||
			h.subUserCalls(userVo, session)
 | 
			
		||||
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
@@ -132,63 +137,63 @@ func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.ChatConfig.EnableContext {
 | 
			
		||||
			if h.App.SysConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
				// for prompt
 | 
			
		||||
				promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.PromptMsg,
 | 
			
		||||
					Icon:       userVo.Avatar,
 | 
			
		||||
					Content:    prompt,
 | 
			
		||||
					Tokens:     promptToken,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
				historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
				res := h.db.Save(&historyUserMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// for reply
 | 
			
		||||
				// 计算本次对话消耗的总 token 数量
 | 
			
		||||
				replyToken, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
				totalTokens := replyToken + getTotalTokens(req)
 | 
			
		||||
				historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.ReplyMsg,
 | 
			
		||||
					Icon:       role.Icon,
 | 
			
		||||
					Content:    message.Content,
 | 
			
		||||
					Tokens:     totalTokens,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
				historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
				res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
				// 更新用户信息
 | 
			
		||||
				h.incUserTokenFee(userVo.Id, totalTokens)
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for reply
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			totalTokens := replyTokens + getTotalTokens(req)
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     totalTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
@@ -199,7 +204,8 @@ func (h *ChatHandler) sendBaiduMessage(
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				h.db.Create(&chatItem)
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,56 +1,65 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const ErrorMsg = "抱歉,AI 助手开小差了,请稍后再试。"
 | 
			
		||||
const ErrImg = ""
 | 
			
		||||
 | 
			
		||||
var ErrImg = ""
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type ChatHandler struct {
 | 
			
		||||
	handler.BaseHandler
 | 
			
		||||
	db        *gorm.DB
 | 
			
		||||
	leveldb   *store.LevelDB
 | 
			
		||||
	redis     *redis.Client
 | 
			
		||||
	mjService *mj.Service
 | 
			
		||||
	redis         *redis.Client
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, levelDB *store.LevelDB, redis *redis.Client, service *mj.Service) *ChatHandler {
 | 
			
		||||
	h := ChatHandler{
 | 
			
		||||
		db:        db,
 | 
			
		||||
		leveldb:   levelDB,
 | 
			
		||||
		redis:     redis,
 | 
			
		||||
		mjService: service,
 | 
			
		||||
func NewChatHandler(app *core.AppServer, db *gorm.DB, redis *redis.Client, manager *oss.UploaderManager) *ChatHandler {
 | 
			
		||||
	return &ChatHandler{
 | 
			
		||||
		BaseHandler:   handler.BaseHandler{App: app, DB: db},
 | 
			
		||||
		redis:         redis,
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
	}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var chatConfig types.ChatConfig
 | 
			
		||||
func (h *ChatHandler) Init() {
 | 
			
		||||
	// 如果后台有上传微信客服微信二维码,则覆盖
 | 
			
		||||
	if h.App.SysConfig.WechatCardURL != "" {
 | 
			
		||||
		ErrImg = fmt.Sprintf("", h.App.SysConfig.WechatCardURL)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ChatHandle 处理聊天 WebSocket 请求
 | 
			
		||||
func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
@@ -66,9 +75,20 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
	modelId := h.GetInt(c, "model_id", 0)
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	var chatRole model.ChatRole
 | 
			
		||||
	res := h.DB.First(&chatRole, roleId)
 | 
			
		||||
	if res.Error != nil || !chatRole.Enable {
 | 
			
		||||
		utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// if the role bind a model_id, use role's bind model_id
 | 
			
		||||
	if chatRole.ModelId > 0 {
 | 
			
		||||
		modelId = chatRole.ModelId
 | 
			
		||||
	}
 | 
			
		||||
	// get model info
 | 
			
		||||
	var chatModel model.ChatModel
 | 
			
		||||
	res := h.db.First(&chatModel, modelId)
 | 
			
		||||
	res = h.DB.First(&chatModel, modelId)
 | 
			
		||||
	if res.Error != nil || chatModel.Enabled == false {
 | 
			
		||||
		utils.ReplyMessage(client, "当前AI模型暂未启用,连接已关闭!!!")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
@@ -77,7 +97,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	session := h.App.ChatSession.Get(sessionId)
 | 
			
		||||
	if session == nil {
 | 
			
		||||
		user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
		user, err := h.GetLoginUser(c)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Info("用户未登录")
 | 
			
		||||
			c.Abort()
 | 
			
		||||
@@ -86,7 +106,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
		session = &types.ChatSession{
 | 
			
		||||
			SessionId: sessionId,
 | 
			
		||||
			ClientIP:  c.ClientIP(),
 | 
			
		||||
			Username:  user.Mobile,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
		}
 | 
			
		||||
		h.App.ChatSession.Put(sessionId, session)
 | 
			
		||||
@@ -94,7 +114,7 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	// use old chat data override the chat model and role ID
 | 
			
		||||
	var chat model.ChatItem
 | 
			
		||||
	res = h.db.Where("chat_id=?", chatId).First(&chat)
 | 
			
		||||
	res = h.DB.Where("chat_id = ?", chatId).First(&chat)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		chatModel.Id = chat.ModelId
 | 
			
		||||
		roleId = int(chat.RoleId)
 | 
			
		||||
@@ -102,28 +122,18 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	session.ChatId = chatId
 | 
			
		||||
	session.Model = types.ChatModel{
 | 
			
		||||
		Id:       chatModel.Id,
 | 
			
		||||
		Value:    chatModel.Value,
 | 
			
		||||
		Weight:   chatModel.Weight,
 | 
			
		||||
		Platform: types.Platform(chatModel.Platform)}
 | 
			
		||||
		Id:          chatModel.Id,
 | 
			
		||||
		Name:        chatModel.Name,
 | 
			
		||||
		Value:       chatModel.Value,
 | 
			
		||||
		Power:       chatModel.Power,
 | 
			
		||||
		MaxTokens:   chatModel.MaxTokens,
 | 
			
		||||
		MaxContext:  chatModel.MaxContext,
 | 
			
		||||
		Temperature: chatModel.Temperature,
 | 
			
		||||
		KeyId:       chatModel.KeyId,
 | 
			
		||||
		Platform:    types.Platform(chatModel.Platform)}
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s, Username: %s", c.ClientIP(), session.Username)
 | 
			
		||||
	var chatRole model.ChatRole
 | 
			
		||||
	res = h.db.First(&chatRole, roleId)
 | 
			
		||||
	if res.Error != nil || !chatRole.Enable {
 | 
			
		||||
		utils.ReplyMessage(client, "当前聊天角色不存在或者未启用,连接已关闭!!!")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 初始化聊天配置
 | 
			
		||||
	var config model.Config
 | 
			
		||||
	h.db.Where("marker", "chat").First(&config)
 | 
			
		||||
	err = utils.JsonDecode(config.Config, &chatConfig)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		utils.ReplyMessage(client, "加载系统配置失败,连接已关闭!!!")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	h.Init()
 | 
			
		||||
 | 
			
		||||
	// 保存会话连接
 | 
			
		||||
	h.App.ChatClients.Put(sessionId, client)
 | 
			
		||||
@@ -131,9 +141,10 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
		for {
 | 
			
		||||
			_, msg, err := client.Receive()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				logger.Debugf("close connection: %s", client.Conn.RemoteAddr())
 | 
			
		||||
				client.Close()
 | 
			
		||||
				h.App.ChatClients.Delete(sessionId)
 | 
			
		||||
				h.App.ChatSession.Delete(sessionId)
 | 
			
		||||
				cancelFunc := h.App.ReqCancelFunc.Get(sessionId)
 | 
			
		||||
				if cancelFunc != nil {
 | 
			
		||||
					cancelFunc()
 | 
			
		||||
@@ -142,19 +153,30 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			message := string(msg)
 | 
			
		||||
			logger.Info("Receive a message: ", message)
 | 
			
		||||
			//utils.ReplyMessage(client, "这是一条测试消息!")
 | 
			
		||||
			var message types.WsMessage
 | 
			
		||||
			err = utils.JsonDecode(string(msg), &message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 心跳消息
 | 
			
		||||
			if message.Type == "heartbeat" {
 | 
			
		||||
				logger.Debug("收到 Chat 心跳消息:", message.Content)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logger.Info("Receive a message: ", message.Content)
 | 
			
		||||
 | 
			
		||||
			ctx, cancel := context.WithCancel(context.Background())
 | 
			
		||||
			h.App.ReqCancelFunc.Put(sessionId, cancel)
 | 
			
		||||
			// 回复消息
 | 
			
		||||
			err = h.sendMessage(ctx, session, chatRole, message, client)
 | 
			
		||||
			err = h.sendMessage(ctx, session, chatRole, utils.InterfaceToString(message.Content), client)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
 | 
			
		||||
				utils.ReplyMessage(client, err.Error())
 | 
			
		||||
			} else {
 | 
			
		||||
				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
 | 
			
		||||
				logger.Info("回答完毕: " + string(message))
 | 
			
		||||
				logger.Infof("回答完毕: %v", message.Content)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
@@ -162,17 +184,18 @@ func (h *ChatHandler) ChatHandle(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSession, role model.ChatRole, prompt string, ws *types.WsClient) error {
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if r := recover(); r != nil {
 | 
			
		||||
			logger.Error("Recover message from error: ", r)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	if !h.App.Debug {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if r := recover(); r != nil {
 | 
			
		||||
				logger.Error("Recover message from error: ", r)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.db.Model(&model.User{}).First(&user, session.UserId)
 | 
			
		||||
	res := h.DB.Model(&model.User{}).First(&user, session.UserId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		utils.ReplyMessage(ws, "非法用户,请联系管理员!")
 | 
			
		||||
		return res.Error
 | 
			
		||||
		return errors.New("未授权用户,您正在进行非法操作!")
 | 
			
		||||
	}
 | 
			
		||||
	var userVo vo.User
 | 
			
		||||
	err := utils.CopyObject(user, &userVo)
 | 
			
		||||
@@ -182,105 +205,97 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if userVo.Status == false {
 | 
			
		||||
		utils.ReplyMessage(ws, "您的账号已经被禁用,如果疑问,请联系管理员!")
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		return nil
 | 
			
		||||
		return errors.New("您的账号已经被禁用,如果疑问,请联系管理员!")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if userVo.Calls < session.Model.Weight {
 | 
			
		||||
		utils.ReplyMessage(ws, fmt.Sprintf("您当前剩余对话次数(%d)已不足以支付当前模型的单次对话需要消耗的对话额度(%d)!", userVo.Calls, session.Model.Weight))
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if userVo.Calls <= 0 && userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
 | 
			
		||||
		utils.ReplyMessage(ws, "您的对话次数已经用尽,请联系管理员或者充值点卡继续对话!")
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		return nil
 | 
			
		||||
	if userVo.Power < session.Model.Power {
 | 
			
		||||
		return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型的单次对话需要消耗的算力(%d)!", userVo.Power, session.Model.Power)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if userVo.ExpiredTime > 0 && userVo.ExpiredTime <= time.Now().Unix() {
 | 
			
		||||
		utils.ReplyMessage(ws, "您的账号已经过期,请联系管理员!")
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		return nil
 | 
			
		||||
		return errors.New("您的账号已经过期,请联系管理员!")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查 prompt 长度是否超过了当前模型允许的最大上下文长度
 | 
			
		||||
	promptTokens, err := utils.CalcTokens(prompt, session.Model.Value)
 | 
			
		||||
	if promptTokens > session.Model.MaxContext {
 | 
			
		||||
 | 
			
		||||
		return errors.New("对话内容超出了当前模型允许的最大上下文长度!")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var req = types.ApiRequest{
 | 
			
		||||
		Model:  session.Model.Value,
 | 
			
		||||
		Stream: true,
 | 
			
		||||
	}
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.Azure.Temperature
 | 
			
		||||
		req.MaxTokens = h.App.ChatConfig.Azure.MaxTokens
 | 
			
		||||
	case types.Azure, types.ChatGLM, types.Baidu, types.XunFei:
 | 
			
		||||
		req.Temperature = session.Model.Temperature
 | 
			
		||||
		req.MaxTokens = session.Model.MaxTokens
 | 
			
		||||
		break
 | 
			
		||||
	case types.ChatGLM:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.ChatGML.Temperature
 | 
			
		||||
		req.MaxTokens = h.App.ChatConfig.ChatGML.MaxTokens
 | 
			
		||||
		break
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.OpenAI.Temperature
 | 
			
		||||
		// TODO: 目前只支持 ERNIE-Bot-turbo 模型,如果是 ERNIE-Bot 模型则需要增加函数支持
 | 
			
		||||
	case types.OpenAI:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.OpenAI.Temperature
 | 
			
		||||
		req.MaxTokens = h.App.ChatConfig.OpenAI.MaxTokens
 | 
			
		||||
		req.Temperature = session.Model.Temperature
 | 
			
		||||
		req.MaxTokens = session.Model.MaxTokens
 | 
			
		||||
		// OpenAI 支持函数功能
 | 
			
		||||
		if h.App.SysConfig.EnabledFunction {
 | 
			
		||||
			var functions = make([]types.Function, 0)
 | 
			
		||||
			for _, f := range types.InnerFunctions {
 | 
			
		||||
				if !h.App.SysConfig.EnabledDraw && f.Name == types.FuncMidJourney {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				functions = append(functions, f)
 | 
			
		||||
			}
 | 
			
		||||
			req.Functions = functions
 | 
			
		||||
		var items []model.Function
 | 
			
		||||
		res := h.DB.Where("enabled", true).Find(&items)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
	case types.XunFei:
 | 
			
		||||
		req.Temperature = h.App.ChatConfig.XunFei.Temperature
 | 
			
		||||
		req.MaxTokens = h.App.ChatConfig.XunFei.MaxTokens
 | 
			
		||||
 | 
			
		||||
		var tools = make([]types.Tool, 0)
 | 
			
		||||
		for _, v := range items {
 | 
			
		||||
			var parameters map[string]interface{}
 | 
			
		||||
			err = utils.JsonDecode(v.Parameters, ¶meters)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			required := parameters["required"]
 | 
			
		||||
			delete(parameters, "required")
 | 
			
		||||
			tool := types.Tool{
 | 
			
		||||
				Type: "function",
 | 
			
		||||
				Function: types.Function{
 | 
			
		||||
					Name:        v.Name,
 | 
			
		||||
					Description: v.Description,
 | 
			
		||||
					Parameters:  parameters,
 | 
			
		||||
				},
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Fixed: compatible for gpt4-turbo-xxx model
 | 
			
		||||
			if !strings.HasPrefix(req.Model, "gpt-4-turbo-") {
 | 
			
		||||
				tool.Function.Required = required
 | 
			
		||||
			}
 | 
			
		||||
			tools = append(tools, tool)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if len(tools) > 0 {
 | 
			
		||||
			req.Tools = tools
 | 
			
		||||
			req.ToolChoice = "auto"
 | 
			
		||||
		}
 | 
			
		||||
	case types.QWen:
 | 
			
		||||
		req.Parameters = map[string]interface{}{
 | 
			
		||||
			"max_tokens":  session.Model.MaxTokens,
 | 
			
		||||
			"temperature": session.Model.Temperature,
 | 
			
		||||
		}
 | 
			
		||||
		break
 | 
			
		||||
 | 
			
		||||
	default:
 | 
			
		||||
		utils.ReplyMessage(ws, "不支持的平台:"+session.Model.Platform+",请联系管理员!")
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		return nil
 | 
			
		||||
		return fmt.Errorf("不支持的平台:%s", session.Model.Platform)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 加载聊天上下文
 | 
			
		||||
	var chatCtx []interface{}
 | 
			
		||||
	if h.App.ChatConfig.EnableContext {
 | 
			
		||||
	chatCtx := make([]types.Message, 0)
 | 
			
		||||
	messages := make([]types.Message, 0)
 | 
			
		||||
	if h.App.SysConfig.EnableContext {
 | 
			
		||||
		if h.App.ChatContexts.Has(session.ChatId) {
 | 
			
		||||
			chatCtx = h.App.ChatContexts.Get(session.ChatId)
 | 
			
		||||
			messages = h.App.ChatContexts.Get(session.ChatId)
 | 
			
		||||
		} else {
 | 
			
		||||
			// calculate the tokens of current request, to prevent to exceeding the max tokens num
 | 
			
		||||
			tokens := req.MaxTokens
 | 
			
		||||
			for _, f := range types.InnerFunctions {
 | 
			
		||||
				tks, _ := utils.CalcTokens(utils.JsonEncode(f), req.Model)
 | 
			
		||||
				tokens += tks
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// loading the role context
 | 
			
		||||
			var messages []types.Message
 | 
			
		||||
			err := utils.JsonDecode(role.Context, &messages)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				for _, v := range messages {
 | 
			
		||||
					tks, _ := utils.CalcTokens(v.Content, req.Model)
 | 
			
		||||
					if tokens+tks >= types.ModelToTokens[req.Model] {
 | 
			
		||||
						break
 | 
			
		||||
					}
 | 
			
		||||
					tokens += tks
 | 
			
		||||
					chatCtx = append(chatCtx, v)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// loading recent chat history as chat context
 | 
			
		||||
			if chatConfig.ContextDeep > 0 {
 | 
			
		||||
				var historyMessages []model.HistoryMessage
 | 
			
		||||
				res := h.db.Debug().Where("chat_id = ? and use_context = 1", session.ChatId).Limit(chatConfig.ContextDeep).Order("id desc").Find(&historyMessages)
 | 
			
		||||
			_ = utils.JsonDecode(role.Context, &messages)
 | 
			
		||||
			if h.App.SysConfig.ContextDeep > 0 {
 | 
			
		||||
				var historyMessages []model.ChatMessage
 | 
			
		||||
				res := h.DB.Where("chat_id = ? and use_context = 1", session.ChatId).Limit(h.App.SysConfig.ContextDeep).Order("id DESC").Find(&historyMessages)
 | 
			
		||||
				if res.Error == nil {
 | 
			
		||||
					for i := len(historyMessages) - 1; i >= 0; i-- {
 | 
			
		||||
						msg := historyMessages[i]
 | 
			
		||||
						if tokens+msg.Tokens >= types.ModelToTokens[session.Model.Value] {
 | 
			
		||||
							break
 | 
			
		||||
						}
 | 
			
		||||
						tokens += msg.Tokens
 | 
			
		||||
						ms := types.Message{Role: "user", Content: msg.Content}
 | 
			
		||||
						if msg.Type == types.ReplyMsg {
 | 
			
		||||
							ms.Role = "assistant"
 | 
			
		||||
@@ -290,6 +305,29 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 计算当前请求的 token 总长度,确保不会超出最大上下文长度
 | 
			
		||||
		// MaxContextLength = Response + Tool + Prompt + Context
 | 
			
		||||
		tokens := req.MaxTokens // 最大响应长度
 | 
			
		||||
		tks, _ := utils.CalcTokens(utils.JsonEncode(req.Tools), req.Model)
 | 
			
		||||
		tokens += tks + promptTokens
 | 
			
		||||
 | 
			
		||||
		for _, v := range messages {
 | 
			
		||||
			tks, _ := utils.CalcTokens(v.Content, req.Model)
 | 
			
		||||
			// 上下文 token 超出了模型的最大上下文长度
 | 
			
		||||
			if tokens+tks >= session.Model.MaxContext {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 上下文的深度超出了模型的最大上下文深度
 | 
			
		||||
			if len(chatCtx) >= h.App.SysConfig.ContextDeep {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			tokens += tks
 | 
			
		||||
			chatCtx = append(chatCtx, v)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		logger.Debugf("聊天上下文:%+v", chatCtx)
 | 
			
		||||
	}
 | 
			
		||||
	reqMgs := make([]interface{}, 0)
 | 
			
		||||
@@ -297,10 +335,49 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		reqMgs = append(reqMgs, m)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	req.Messages = append(reqMgs, map[string]interface{}{
 | 
			
		||||
		"role":    "user",
 | 
			
		||||
		"content": prompt,
 | 
			
		||||
	})
 | 
			
		||||
	if session.Model.Platform == types.QWen {
 | 
			
		||||
		req.Input = make(map[string]interface{})
 | 
			
		||||
		reqMgs = append(reqMgs, types.Message{
 | 
			
		||||
			Role:    "user",
 | 
			
		||||
			Content: prompt,
 | 
			
		||||
		})
 | 
			
		||||
		req.Input["messages"] = reqMgs
 | 
			
		||||
	} else if session.Model.Platform == types.OpenAI { // extract image for gpt-vision model
 | 
			
		||||
		imgURLs := utils.ExtractImgURL(prompt)
 | 
			
		||||
		logger.Debugf("detected IMG: %+v", imgURLs)
 | 
			
		||||
		var content interface{}
 | 
			
		||||
		if len(imgURLs) > 0 {
 | 
			
		||||
			data := make([]interface{}, 0)
 | 
			
		||||
			text := prompt
 | 
			
		||||
			for _, v := range imgURLs {
 | 
			
		||||
				text = strings.Replace(text, v, "", 1)
 | 
			
		||||
				data = append(data, gin.H{
 | 
			
		||||
					"type": "image_url",
 | 
			
		||||
					"image_url": gin.H{
 | 
			
		||||
						"url": v,
 | 
			
		||||
					},
 | 
			
		||||
				})
 | 
			
		||||
			}
 | 
			
		||||
			data = append(data, gin.H{
 | 
			
		||||
				"type": "text",
 | 
			
		||||
				"text": text,
 | 
			
		||||
			})
 | 
			
		||||
			content = data
 | 
			
		||||
		} else {
 | 
			
		||||
			content = prompt
 | 
			
		||||
		}
 | 
			
		||||
		req.Messages = append(reqMgs, map[string]interface{}{
 | 
			
		||||
			"role":    "user",
 | 
			
		||||
			"content": content,
 | 
			
		||||
		})
 | 
			
		||||
	} else {
 | 
			
		||||
		req.Messages = append(reqMgs, map[string]interface{}{
 | 
			
		||||
			"role":    "user",
 | 
			
		||||
			"content": prompt,
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("%+v", req.Messages)
 | 
			
		||||
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure:
 | 
			
		||||
@@ -313,20 +390,19 @@ func (h *ChatHandler) sendMessage(ctx context.Context, session *types.ChatSessio
 | 
			
		||||
		return h.sendBaiduMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	case types.XunFei:
 | 
			
		||||
		return h.sendXunFeiMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
 | 
			
		||||
	case types.QWen:
 | 
			
		||||
		return h.sendQWenMessage(chatCtx, req, userVo, ctx, session, role, prompt, ws)
 | 
			
		||||
	}
 | 
			
		||||
	utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
		Type:    types.WsMiddle,
 | 
			
		||||
		Content: fmt.Sprintf("Not supported platform: %s", session.Model.Platform),
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Tokens 统计 token 数量
 | 
			
		||||
func (h *ChatHandler) Tokens(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Text  string `json:"text"`
 | 
			
		||||
		Model string `json:"model"`
 | 
			
		||||
		Text   string `json:"text"`
 | 
			
		||||
		Model  string `json:"model"`
 | 
			
		||||
		ChatId string `json:"chat_id"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
@@ -334,10 +410,10 @@ func (h *ChatHandler) Tokens(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果没有传入 text 字段,则说明是获取当前 reply 总的 token 消耗(带上下文)
 | 
			
		||||
	if data.Text == "" {
 | 
			
		||||
		var item model.HistoryMessage
 | 
			
		||||
	if data.Text == "" && data.ChatId != "" {
 | 
			
		||||
		var item model.ChatMessage
 | 
			
		||||
		userId, _ := c.Get(types.LoginUserID)
 | 
			
		||||
		res := h.db.Where("user_id = ?", userId).Last(&item)
 | 
			
		||||
		res := h.DB.Where("user_id = ?", userId).Where("chat_id = ?", data.ChatId).Last(&item)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, res.Error.Error())
 | 
			
		||||
			return
 | 
			
		||||
@@ -387,39 +463,45 @@ func (h *ChatHandler) StopGenerate(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// 发送请求到 OpenAI 服务器
 | 
			
		||||
// useOwnApiKey: 是否使用了用户自己的 API KEY
 | 
			
		||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platform types.Platform, apiKey *string) (*http.Response, error) {
 | 
			
		||||
func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, session *types.ChatSession, apiKey *model.ApiKey) (*http.Response, error) {
 | 
			
		||||
	// if the chat model bind a KEY, use it directly
 | 
			
		||||
	if session.Model.KeyId > 0 {
 | 
			
		||||
		h.DB.Debug().Where("id", session.Model.KeyId).Where("enabled", true).Find(apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	// use the last unused key
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
		h.DB.Debug().Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
		return nil, errors.New("no available key, please import key")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var apiURL string
 | 
			
		||||
	switch platform {
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure:
 | 
			
		||||
		md := strings.Replace(req.Model, ".", "", 1)
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.Azure.ApiURL, "{model}", md, 1)
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", md, 1)
 | 
			
		||||
		break
 | 
			
		||||
	case types.ChatGLM:
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.ChatGML.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		req.Prompt = req.Messages // 使用 prompt 字段替代 message 字段
 | 
			
		||||
		req.Messages = nil
 | 
			
		||||
		break
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.Baidu.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		apiURL = strings.Replace(apiKey.ApiURL, "{model}", req.Model, 1)
 | 
			
		||||
		break
 | 
			
		||||
	case types.QWen:
 | 
			
		||||
		apiURL = apiKey.ApiURL
 | 
			
		||||
		req.Messages = nil
 | 
			
		||||
		break
 | 
			
		||||
	default:
 | 
			
		||||
		apiURL = h.App.ChatConfig.OpenAI.ApiURL
 | 
			
		||||
		apiURL = apiKey.ApiURL
 | 
			
		||||
	}
 | 
			
		||||
	if *apiKey == "" {
 | 
			
		||||
		var key model.ApiKey
 | 
			
		||||
		res := h.db.Where("platform = ?", platform).Order("last_used_at ASC").First(&key)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			return nil, errors.New("no available key, please import key")
 | 
			
		||||
		}
 | 
			
		||||
		// 更新 API KEY 的最后使用时间
 | 
			
		||||
		h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
		*apiKey = key.Value
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新 API KEY 的最后使用时间
 | 
			
		||||
	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
	// 百度文心,需要串接 access_token
 | 
			
		||||
	if platform == types.Baidu {
 | 
			
		||||
		token, err := h.getBaiduToken(*apiKey)
 | 
			
		||||
	if session.Model.Platform == types.Baidu {
 | 
			
		||||
		token, err := h.getBaiduToken(apiKey.Value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
@@ -427,6 +509,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
 | 
			
		||||
		apiURL = fmt.Sprintf("%s?access_token=%s", apiURL, token)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf(utils.JsonEncode(req))
 | 
			
		||||
 | 
			
		||||
	// 创建 HttpClient 请求对象
 | 
			
		||||
	var client *http.Client
 | 
			
		||||
	requestBody, err := json.Marshal(req)
 | 
			
		||||
@@ -440,9 +524,8 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
 | 
			
		||||
 | 
			
		||||
	request = request.WithContext(ctx)
 | 
			
		||||
	request.Header.Set("Content-Type", "application/json")
 | 
			
		||||
	proxyURL := h.App.Config.ProxyURL
 | 
			
		||||
	if proxyURL != "" && platform == types.OpenAI { // 使用代理
 | 
			
		||||
		proxy, _ := url.Parse(proxyURL)
 | 
			
		||||
	if len(apiKey.ProxyURL) > 5 { // 使用代理
 | 
			
		||||
		proxy, _ := url.Parse(apiKey.ProxyURL)
 | 
			
		||||
		client = &http.Client{
 | 
			
		||||
			Transport: &http.Transport{
 | 
			
		||||
				Proxy: http.ProxyURL(proxy),
 | 
			
		||||
@@ -451,42 +534,79 @@ func (h *ChatHandler) doRequest(ctx context.Context, req types.ApiRequest, platf
 | 
			
		||||
	} else {
 | 
			
		||||
		client = http.DefaultClient
 | 
			
		||||
	}
 | 
			
		||||
	logger.Infof("Sending %s request, KEY: %s, PROXY: %s, Model: %s", platform, *apiKey, proxyURL, req.Model)
 | 
			
		||||
	switch platform {
 | 
			
		||||
	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
 | 
			
		||||
	switch session.Model.Platform {
 | 
			
		||||
	case types.Azure:
 | 
			
		||||
		request.Header.Set("api-key", *apiKey)
 | 
			
		||||
		request.Header.Set("api-key", apiKey.Value)
 | 
			
		||||
		break
 | 
			
		||||
	case types.ChatGLM:
 | 
			
		||||
		token, err := h.getChatGLMToken(*apiKey)
 | 
			
		||||
		token, err := h.getChatGLMToken(apiKey.Value)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return nil, err
 | 
			
		||||
		}
 | 
			
		||||
		logger.Info(token)
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token))
 | 
			
		||||
		break
 | 
			
		||||
	case types.Baidu:
 | 
			
		||||
		request.RequestURI = ""
 | 
			
		||||
	case types.OpenAI:
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", *apiKey))
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
		break
 | 
			
		||||
	case types.QWen:
 | 
			
		||||
		request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
		request.Header.Set("X-DashScope-SSE", "enable")
 | 
			
		||||
		break
 | 
			
		||||
	}
 | 
			
		||||
	return client.Do(request)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 扣减用户的对话次数
 | 
			
		||||
func (h *ChatHandler) subUserCalls(userVo vo.User, session *types.ChatSession) {
 | 
			
		||||
	// 仅当用户没有导入自己的 API KEY 时才进行扣减
 | 
			
		||||
	if userVo.ChatConfig.ApiKeys[session.Model.Platform] == "" {
 | 
			
		||||
		num := 1
 | 
			
		||||
		if session.Model.Weight > 0 {
 | 
			
		||||
			num = session.Model.Weight
 | 
			
		||||
		}
 | 
			
		||||
		h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("calls", gorm.Expr("calls - ?", num))
 | 
			
		||||
// 扣减用户算力
 | 
			
		||||
func (h *ChatHandler) subUserPower(userVo vo.User, session *types.ChatSession, promptTokens int, replyTokens int) {
 | 
			
		||||
	power := 1
 | 
			
		||||
	if session.Model.Power > 0 {
 | 
			
		||||
		power = session.Model.Power
 | 
			
		||||
	}
 | 
			
		||||
	res := h.DB.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("power", gorm.Expr("power - ?", power))
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		// 记录算力消费日志
 | 
			
		||||
		var u model.User
 | 
			
		||||
		h.DB.Where("id", userVo.Id).First(&u)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    userVo.Id,
 | 
			
		||||
			Username:  userVo.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Balance:   u.Power,
 | 
			
		||||
			Model:     session.Model.Value,
 | 
			
		||||
			Remark:    fmt.Sprintf("模型名称:%s, 提问长度:%d,回复长度:%d", session.Model.Name, promptTokens, replyTokens),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) incUserTokenFee(userId uint, tokens int) {
 | 
			
		||||
	h.db.Model(&model.User{}).Where("id = ?", userId).
 | 
			
		||||
		UpdateColumn("total_tokens", gorm.Expr("total_tokens + ?", tokens))
 | 
			
		||||
	h.db.Model(&model.User{}).Where("id = ?", userId).
 | 
			
		||||
		UpdateColumn("tokens", gorm.Expr("tokens + ?", tokens))
 | 
			
		||||
// 将AI回复消息中生成的图片链接下载到本地
 | 
			
		||||
func (h *ChatHandler) extractImgUrl(text string) string {
 | 
			
		||||
	pattern := `!\[([^\]]*)]\(([^)]+)\)`
 | 
			
		||||
	re := regexp.MustCompile(pattern)
 | 
			
		||||
	matches := re.FindAllStringSubmatch(text, -1)
 | 
			
		||||
 | 
			
		||||
	// 下载图片并替换链接地址
 | 
			
		||||
	for _, match := range matches {
 | 
			
		||||
		imageURL := match[2]
 | 
			
		||||
		logger.Debug(imageURL)
 | 
			
		||||
		// 对于相同地址的图片,已经被替换了,就不再重复下载了
 | 
			
		||||
		if !strings.Contains(text, imageURL) {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		newImgURL, err := h.uploadManager.GetUploadHandler().PutImg(imageURL, false)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with download image: ", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		text = strings.ReplaceAll(text, imageURL, newImgURL)
 | 
			
		||||
	}
 | 
			
		||||
	return text
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,32 +1,41 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// List 获取会话列表
 | 
			
		||||
func (h *ChatHandler) List(c *gin.Context) {
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		resp.ERROR(c, "The parameter 'user_id' is needed.")
 | 
			
		||||
	if !h.IsLogin(c) {
 | 
			
		||||
		resp.SUCCESS(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var items = make([]vo.ChatItem, 0)
 | 
			
		||||
	var chats []model.ChatItem
 | 
			
		||||
	res := h.db.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
 | 
			
		||||
	res := h.DB.Where("user_id = ?", userId).Order("id DESC").Find(&chats)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		var roleIds = make([]uint, 0)
 | 
			
		||||
		for _, chat := range chats {
 | 
			
		||||
			roleIds = append(roleIds, chat.RoleId)
 | 
			
		||||
		}
 | 
			
		||||
		var roles []model.ChatRole
 | 
			
		||||
		res = h.db.Find(&roles, roleIds)
 | 
			
		||||
		res = h.DB.Find(&roles, roleIds)
 | 
			
		||||
		if res.Error == nil {
 | 
			
		||||
			roleMap := make(map[uint]model.ChatRole)
 | 
			
		||||
			for _, role := range roles {
 | 
			
		||||
@@ -58,7 +67,7 @@ func (h *ChatHandler) Update(c *gin.Context) {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	res := h.db.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
 | 
			
		||||
	res := h.DB.Model(&model.ChatItem{}).Where("chat_id = ?", data.ChatId).UpdateColumn("title", data.Title)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Failed to update database")
 | 
			
		||||
		return
 | 
			
		||||
@@ -70,14 +79,14 @@ func (h *ChatHandler) Update(c *gin.Context) {
 | 
			
		||||
// Clear 清空所有聊天记录
 | 
			
		||||
func (h *ChatHandler) Clear(c *gin.Context) {
 | 
			
		||||
	// 获取当前登录用户所有的聊天会话
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var chats []model.ChatItem
 | 
			
		||||
	res := h.db.Where("user_id = ?", user.Id).Find(&chats)
 | 
			
		||||
	res := h.DB.Where("user_id = ?", user.Id).Find(&chats)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "No chats found")
 | 
			
		||||
		return
 | 
			
		||||
@@ -89,13 +98,13 @@ func (h *ChatHandler) Clear(c *gin.Context) {
 | 
			
		||||
		// 清空会话上下文
 | 
			
		||||
		h.App.ChatContexts.Delete(chat.ChatId)
 | 
			
		||||
	}
 | 
			
		||||
	err = h.db.Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		res := h.db.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
 | 
			
		||||
	err = h.DB.Transaction(func(tx *gorm.DB) error {
 | 
			
		||||
		res := h.DB.Where("user_id =?", user.Id).Delete(&model.ChatItem{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			return res.Error
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		res = h.db.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.HistoryMessage{})
 | 
			
		||||
		res = h.DB.Where("user_id = ? AND chat_id IN ?", user.Id, chatIds).Delete(&model.ChatMessage{})
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			return res.Error
 | 
			
		||||
		}
 | 
			
		||||
@@ -116,9 +125,9 @@ func (h *ChatHandler) Clear(c *gin.Context) {
 | 
			
		||||
// History 获取聊天历史记录
 | 
			
		||||
func (h *ChatHandler) History(c *gin.Context) {
 | 
			
		||||
	chatId := c.Query("chat_id") // 会话 ID
 | 
			
		||||
	var items []model.HistoryMessage
 | 
			
		||||
	var items []model.ChatMessage
 | 
			
		||||
	var messages = make([]vo.HistoryMessage, 0)
 | 
			
		||||
	res := h.db.Where("chat_id = ?", chatId).Find(&items)
 | 
			
		||||
	res := h.DB.Where("chat_id = ?", chatId).Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "No history message")
 | 
			
		||||
		return
 | 
			
		||||
@@ -144,20 +153,20 @@ func (h *ChatHandler) Remove(c *gin.Context) {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.db.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
 | 
			
		||||
	res := h.DB.Where("user_id = ? AND chat_id = ?", user.Id, chatId).Delete(&model.ChatItem{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Failed to update database")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 删除当前会话的聊天记录
 | 
			
		||||
	res = h.db.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
 | 
			
		||||
	res = h.DB.Where("user_id = ? AND chat_id =?", user.Id, chatId).Delete(&model.ChatItem{})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Failed to remove chat from database.")
 | 
			
		||||
		return
 | 
			
		||||
@@ -179,18 +188,26 @@ func (h *ChatHandler) Detail(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var chatItem model.ChatItem
 | 
			
		||||
	res := h.db.Where("chat_id = ?", chatId).First(&chatItem)
 | 
			
		||||
	res := h.DB.Where("chat_id = ?", chatId).First(&chatItem)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "No chat found")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 填充角色名称
 | 
			
		||||
	var role model.ChatRole
 | 
			
		||||
	res = h.DB.Where("id", chatItem.RoleId).First(&role)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Role not found")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var chatItemVo vo.ChatItem
 | 
			
		||||
	err := utils.CopyObject(chatItem, &chatItemVo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	chatItemVo.RoleName = role.Name
 | 
			
		||||
	resp.SUCCESS(c, chatItemVo)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,15 +1,24 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
@@ -19,7 +28,7 @@ import (
 | 
			
		||||
// 清华大学 ChatGML 消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
@@ -29,22 +38,16 @@ func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			logger.Error(err)
 | 
			
		||||
			return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
@@ -71,6 +74,10 @@ func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
			if strings.HasPrefix(line, "data:") {
 | 
			
		||||
				content = line[5:]
 | 
			
		||||
			}
 | 
			
		||||
			// 处理代码换行
 | 
			
		||||
			if len(content) == 0 {
 | 
			
		||||
				content = "\n"
 | 
			
		||||
			}
 | 
			
		||||
			switch event {
 | 
			
		||||
			case "add":
 | 
			
		||||
				if len(contents) == 0 {
 | 
			
		||||
@@ -102,9 +109,6 @@ func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			// 更新用户的对话次数
 | 
			
		||||
			h.subUserCalls(userVo, session)
 | 
			
		||||
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
@@ -112,63 +116,64 @@ func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.ChatConfig.EnableContext {
 | 
			
		||||
			if h.App.SysConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
				// for prompt
 | 
			
		||||
				promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.PromptMsg,
 | 
			
		||||
					Icon:       userVo.Avatar,
 | 
			
		||||
					Content:    prompt,
 | 
			
		||||
					Tokens:     promptToken,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
				historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
				res := h.db.Save(&historyUserMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// for reply
 | 
			
		||||
				// 计算本次对话消耗的总 token 数量
 | 
			
		||||
				replyToken, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
				totalTokens := replyToken + getTotalTokens(req)
 | 
			
		||||
				historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.ReplyMsg,
 | 
			
		||||
					Icon:       role.Icon,
 | 
			
		||||
					Content:    message.Content,
 | 
			
		||||
					Tokens:     totalTokens,
 | 
			
		||||
					UseContext: true,
 | 
			
		||||
				}
 | 
			
		||||
				historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
				historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
				res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
				// 更新用户信息
 | 
			
		||||
				h.incUserTokenFee(userVo.Id, totalTokens)
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for reply
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			totalTokens := replyTokens + getTotalTokens(req)
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     totalTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
@@ -179,7 +184,8 @@ func (h *ChatHandler) sendChatGLMMessage(
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				h.db.Create(&chatItem)
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,24 +1,34 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
 | 
			
		||||
	req2 "github.com/imroc/req/v3"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// OPenAI 消息发送实现
 | 
			
		||||
func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
@@ -28,22 +38,20 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session.Model.Platform, &apiKey)
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
			return nil
 | 
			
		||||
		} else {
 | 
			
		||||
			logger.Error(err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
		utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
		utils.ReplyMessage(ws, err.Error())
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
@@ -55,10 +63,11 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		var functionCall = false
 | 
			
		||||
		var functionName string
 | 
			
		||||
		var function model.Function
 | 
			
		||||
		var toolCall = false
 | 
			
		||||
		var arguments = make([]string, 0)
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		var isNew = true
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
			if !strings.Contains(line, "data:") || len(line) < 30 {
 | 
			
		||||
@@ -67,44 +76,64 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
 | 
			
		||||
			var responseBody = types.ApiResponse{}
 | 
			
		||||
			err = json.Unmarshal([]byte(line[6:]), &responseBody)
 | 
			
		||||
			if err != nil || len(responseBody.Choices) == 0 { // 数据解析出错
 | 
			
		||||
				logger.Error(err, line)
 | 
			
		||||
				utils.ReplyMessage(ws, ErrorMsg)
 | 
			
		||||
				utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
			if err != nil { // 数据解析出错
 | 
			
		||||
				return errors.New(line)
 | 
			
		||||
			}
 | 
			
		||||
			if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "stop" && len(contents) == 0 {
 | 
			
		||||
				utils.ReplyMessage(ws, "抱歉😔😔😔,AI助手由于未知原因已经停止输出内容。")
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var tool types.ToolCall
 | 
			
		||||
			if len(responseBody.Choices[0].Delta.ToolCalls) > 0 {
 | 
			
		||||
				tool = responseBody.Choices[0].Delta.ToolCalls[0]
 | 
			
		||||
				if toolCall && tool.Function.Name == "" {
 | 
			
		||||
					arguments = append(arguments, tool.Function.Arguments)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 兼容 Function Call
 | 
			
		||||
			fun := responseBody.Choices[0].Delta.FunctionCall
 | 
			
		||||
			if functionCall && fun.Name == "" {
 | 
			
		||||
			if fun.Name != "" {
 | 
			
		||||
				tool = *new(types.ToolCall)
 | 
			
		||||
				tool.Function.Name = fun.Name
 | 
			
		||||
			} else if toolCall {
 | 
			
		||||
				arguments = append(arguments, fun.Arguments)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !utils.IsEmptyValue(fun) {
 | 
			
		||||
				functionName = fun.Name
 | 
			
		||||
				f := h.App.Functions[functionName]
 | 
			
		||||
				if f != nil {
 | 
			
		||||
					functionCall = true
 | 
			
		||||
			if !utils.IsEmptyValue(tool) {
 | 
			
		||||
				res := h.DB.Where("name = ?", tool.Function.Name).First(&function)
 | 
			
		||||
				if res.Error == nil {
 | 
			
		||||
					toolCall = true
 | 
			
		||||
					callMsg := fmt.Sprintf("正在调用工具 `%s` 作答 ...\n\n", function.Label)
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: fmt.Sprintf("正在调用函数 `%s` 作答 ...\n\n", f.Name())})
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsMiddle, Content: callMsg})
 | 
			
		||||
					contents = append(contents, callMsg)
 | 
			
		||||
				}
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "tool_calls" ||
 | 
			
		||||
				responseBody.Choices[0].FinishReason == "function_call" { // 函数调用完毕
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 初始化 role
 | 
			
		||||
			if responseBody.Choices[0].Delta.Role != "" && message.Role == "" {
 | 
			
		||||
				message.Role = responseBody.Choices[0].Delta.Role
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
				continue
 | 
			
		||||
			} else if responseBody.Choices[0].FinishReason != "" {
 | 
			
		||||
			// output stopped
 | 
			
		||||
			if responseBody.Choices[0].FinishReason != "" {
 | 
			
		||||
				break // 输出完成或者输出中断了
 | 
			
		||||
			} else {
 | 
			
		||||
				content := responseBody.Choices[0].Delta.Content
 | 
			
		||||
				contents = append(contents, utils.InterfaceToString(content))
 | 
			
		||||
				if isNew {
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
					isNew = false
 | 
			
		||||
				}
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
			
		||||
@@ -120,55 +149,40 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if functionCall { // 调用函数完成任务
 | 
			
		||||
		if toolCall { // 调用函数完成任务
 | 
			
		||||
			var params map[string]interface{}
 | 
			
		||||
			_ = utils.JsonDecode(strings.Join(arguments, ""), ¶ms)
 | 
			
		||||
			logger.Debugf("函数名称: %s, 函数参数:%s", functionName, params)
 | 
			
		||||
 | 
			
		||||
			// for creating image, check if the user's img_calls > 0
 | 
			
		||||
			if functionName == types.FuncMidJourney && userVo.ImgCalls <= 0 {
 | 
			
		||||
				utils.ReplyMessage(ws, "**当前用户剩余绘图次数已用尽,请扫描下面二维码联系管理员!**")
 | 
			
		||||
				utils.ReplyMessage(ws, ErrImg)
 | 
			
		||||
			logger.Debugf("函数名称: %s, 函数参数:%s", function.Name, params)
 | 
			
		||||
			params["user_id"] = userVo.Id
 | 
			
		||||
			var apiRes types.BizVo
 | 
			
		||||
			r, err := req2.C().R().SetHeader("Content-Type", "application/json").
 | 
			
		||||
				SetHeader("Authorization", function.Token).
 | 
			
		||||
				SetBody(params).
 | 
			
		||||
				SetSuccessResult(&apiRes).Post(function.Action)
 | 
			
		||||
			errMsg := ""
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				errMsg = err.Error()
 | 
			
		||||
			} else if r.IsErrorState() {
 | 
			
		||||
				errMsg = r.Status
 | 
			
		||||
			}
 | 
			
		||||
			if errMsg != "" || apiRes.Code != types.Success {
 | 
			
		||||
				msg := "调用函数工具出错:" + apiRes.Message + errMsg
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: msg,
 | 
			
		||||
				})
 | 
			
		||||
				contents = append(contents, msg)
 | 
			
		||||
			} else {
 | 
			
		||||
				f := h.App.Functions[functionName]
 | 
			
		||||
				if functionName == types.FuncMidJourney {
 | 
			
		||||
					params["user_id"] = userVo.Id
 | 
			
		||||
					params["role_id"] = role.Id
 | 
			
		||||
					params["chat_id"] = session.ChatId
 | 
			
		||||
					params["icon"] = "/images/avatar/mid_journey.png"
 | 
			
		||||
					params["session_id"] = session.SessionId
 | 
			
		||||
				}
 | 
			
		||||
				data, err := f.Invoke(params)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					msg := "调用函数出错:" + err.Error()
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: msg,
 | 
			
		||||
					})
 | 
			
		||||
					contents = append(contents, msg)
 | 
			
		||||
				} else {
 | 
			
		||||
					content := data
 | 
			
		||||
					if functionName == types.FuncMidJourney {
 | 
			
		||||
						content = fmt.Sprintf("绘画提示词:%s 已推送任务到 MidJourney 机器人,请耐心等待任务执行...", data)
 | 
			
		||||
						h.mjService.ChatClients.Put(session.SessionId, ws)
 | 
			
		||||
						// update user's img_calls
 | 
			
		||||
						h.db.Model(&model.User{}).Where("id = ?", userVo.Id).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
					}
 | 
			
		||||
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
						Type:    types.WsMiddle,
 | 
			
		||||
						Content: content,
 | 
			
		||||
					})
 | 
			
		||||
					contents = append(contents, content)
 | 
			
		||||
				}
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: apiRes.Data,
 | 
			
		||||
				})
 | 
			
		||||
				contents = append(contents, utils.InterfaceToString(apiRes.Data))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			// 更新用户的对话次数
 | 
			
		||||
			h.subUserCalls(userVo, session)
 | 
			
		||||
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
@@ -176,77 +190,77 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.ChatConfig.EnableContext && functionCall == false {
 | 
			
		||||
			if h.App.SysConfig.EnableContext && toolCall == false {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
				useContext := true
 | 
			
		||||
				if functionCall {
 | 
			
		||||
					useContext = false
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// for prompt
 | 
			
		||||
				promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error(err)
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.PromptMsg,
 | 
			
		||||
					Icon:       userVo.Avatar,
 | 
			
		||||
					Content:    prompt,
 | 
			
		||||
					Tokens:     promptToken,
 | 
			
		||||
					UseContext: useContext,
 | 
			
		||||
				}
 | 
			
		||||
				historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
				historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
				res := h.db.Save(&historyUserMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 计算本次对话消耗的总 token 数量
 | 
			
		||||
				var totalTokens = 0
 | 
			
		||||
				if functionCall { // prompt + 函数名 + 参数 token
 | 
			
		||||
					tokens, _ := utils.CalcTokens(functionName, req.Model)
 | 
			
		||||
					totalTokens += tokens
 | 
			
		||||
					tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
 | 
			
		||||
					totalTokens += tokens
 | 
			
		||||
				} else {
 | 
			
		||||
					totalTokens, _ = utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
				}
 | 
			
		||||
				totalTokens += getTotalTokens(req)
 | 
			
		||||
 | 
			
		||||
				historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
					UserId:     userVo.Id,
 | 
			
		||||
					ChatId:     session.ChatId,
 | 
			
		||||
					RoleId:     role.Id,
 | 
			
		||||
					Type:       types.ReplyMsg,
 | 
			
		||||
					Icon:       role.Icon,
 | 
			
		||||
					Content:    message.Content,
 | 
			
		||||
					Tokens:     totalTokens,
 | 
			
		||||
					UseContext: useContext,
 | 
			
		||||
				}
 | 
			
		||||
				historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
				historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
				res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
				if res.Error != nil {
 | 
			
		||||
					logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				// 更新用户信息
 | 
			
		||||
				h.incUserTokenFee(userVo.Id, totalTokens)
 | 
			
		||||
			useContext := true
 | 
			
		||||
			if toolCall {
 | 
			
		||||
				useContext = false
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: useContext,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			var replyTokens = 0
 | 
			
		||||
			if toolCall { // prompt + 函数名 + 参数 token
 | 
			
		||||
				tokens, _ := utils.CalcTokens(function.Name, req.Model)
 | 
			
		||||
				replyTokens += tokens
 | 
			
		||||
				tokens, _ = utils.CalcTokens(utils.InterfaceToString(arguments), req.Model)
 | 
			
		||||
				replyTokens += tokens
 | 
			
		||||
			} else {
 | 
			
		||||
				replyTokens, _ = utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			}
 | 
			
		||||
			replyTokens += getTotalTokens(req)
 | 
			
		||||
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    h.extractImgUrl(message.Content),
 | 
			
		||||
				Tokens:     replyTokens,
 | 
			
		||||
				UseContext: useContext,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
@@ -257,17 +271,20 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				h.db.Create(&chatItem)
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 OpenAI API 失败:"+err.Error())
 | 
			
		||||
			return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		var res types.ApiError
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 OpenAI API 失败:\n"+"```\n"+string(body)+"```")
 | 
			
		||||
			return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
@@ -275,7 +292,7 @@ func (h *ChatHandler) sendOpenAiMessage(
 | 
			
		||||
		if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
 | 
			
		||||
			// 移除当前 API key
 | 
			
		||||
			h.db.Where("value = ?", apiKey).Delete(&model.ApiKey{})
 | 
			
		||||
			h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
 | 
			
		||||
		} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
 | 
			
		||||
			utils.ReplyMessage(ws, "请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
 | 
			
		||||
		} else if strings.Contains(res.Error.Message, "This model's maximum context length") {
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										243
									
								
								api/handler/chatimpl/qwen_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										243
									
								
								api/handler/chatimpl/qwen_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,243 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/syndtr/goleveldb/leveldb/errors"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
	"unicode/utf8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type qWenResp struct {
 | 
			
		||||
	Output struct {
 | 
			
		||||
		FinishReason string `json:"finish_reason"`
 | 
			
		||||
		Text         string `json:"text"`
 | 
			
		||||
	} `json:"output,omitempty"`
 | 
			
		||||
	Usage struct {
 | 
			
		||||
		TotalTokens  int `json:"total_tokens"`
 | 
			
		||||
		InputTokens  int `json:"input_tokens"`
 | 
			
		||||
		OutputTokens int `json:"output_tokens"`
 | 
			
		||||
	} `json:"usage,omitempty"`
 | 
			
		||||
	RequestID string `json:"request_id"`
 | 
			
		||||
 | 
			
		||||
	Code    string `json:"code,omitempty"`
 | 
			
		||||
	Message string `json:"message,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 通义千问消息发送实现
 | 
			
		||||
func (h *ChatHandler) sendQWenMessage(
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
	session *types.ChatSession,
 | 
			
		||||
	role model.ChatRole,
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	start := time.Now()
 | 
			
		||||
	var apiKey = model.ApiKey{}
 | 
			
		||||
	response, err := h.doRequest(ctx, req, session, &apiKey)
 | 
			
		||||
	logger.Info("HTTP请求完成,耗时:", time.Now().Sub(start))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
			logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			return nil
 | 
			
		||||
		} else if strings.Contains(err.Error(), "no available key") {
 | 
			
		||||
			return errors.New("抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		}
 | 
			
		||||
		return err
 | 
			
		||||
	} else {
 | 
			
		||||
		defer response.Body.Close()
 | 
			
		||||
	}
 | 
			
		||||
	contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
			
		||||
		replyCreatedAt := time.Now() // 记录回复时间
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		var message = types.Message{}
 | 
			
		||||
		var contents = make([]string, 0)
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
 | 
			
		||||
		var content, lastText, newText string
 | 
			
		||||
		var outPutStart = false
 | 
			
		||||
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
			if len(line) < 5 || strings.HasPrefix(line, "id:") ||
 | 
			
		||||
				strings.HasPrefix(line, "event:") || strings.HasPrefix(line, ":HTTP_STATUS/200") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if !strings.HasPrefix(line, "data:") {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			content = line[5:]
 | 
			
		||||
			var resp qWenResp
 | 
			
		||||
			if len(contents) == 0 { // 发送消息头
 | 
			
		||||
				if !outPutStart {
 | 
			
		||||
					utils.ReplyChunkMessage(ws, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
					outPutStart = true
 | 
			
		||||
					continue
 | 
			
		||||
				} else {
 | 
			
		||||
					// 处理代码换行
 | 
			
		||||
					content = "\n"
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				err := utils.JsonDecode(content, &resp)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error("error with parse data line: ", content)
 | 
			
		||||
					utils.ReplyMessage(ws, fmt.Sprintf("**解析数据行失败:%s**", err))
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
				if resp.Message != "" {
 | 
			
		||||
					utils.ReplyMessage(ws, fmt.Sprintf("**API 返回错误:%s**", resp.Message))
 | 
			
		||||
					break
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			//通过比较 lastText(上一次的文本)和 currentText(当前的文本),
 | 
			
		||||
			//提取出新添加的文本部分。然后只将这部分新文本发送到客户端。
 | 
			
		||||
			//每次循环结束后,lastText 会更新为当前的完整文本,以便于下一次循环进行比较。
 | 
			
		||||
			currentText := resp.Output.Text
 | 
			
		||||
			if currentText != lastText {
 | 
			
		||||
				// 提取新增文本
 | 
			
		||||
				newText = strings.Replace(currentText, lastText, "", 1)
 | 
			
		||||
				utils.ReplyChunkMessage(ws, types.WsMessage{
 | 
			
		||||
					Type:    types.WsMiddle,
 | 
			
		||||
					Content: utils.InterfaceToString(newText),
 | 
			
		||||
				})
 | 
			
		||||
				lastText = currentText // 更新 lastText
 | 
			
		||||
			}
 | 
			
		||||
			contents = append(contents, newText)
 | 
			
		||||
 | 
			
		||||
			if resp.Output.FinishReason == "stop" {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		} //end for
 | 
			
		||||
 | 
			
		||||
		if err := scanner.Err(); err != nil {
 | 
			
		||||
			if strings.Contains(err.Error(), "context canceled") {
 | 
			
		||||
				logger.Info("用户取消了请求:", prompt)
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Error("信息读取出错:", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 消息发送成功
 | 
			
		||||
		if len(contents) > 0 {
 | 
			
		||||
			if message.Role == "" {
 | 
			
		||||
				message.Role = "assistant"
 | 
			
		||||
			}
 | 
			
		||||
			message.Content = strings.Join(contents, "")
 | 
			
		||||
			useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
			// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
			if h.App.SysConfig.EnableContext {
 | 
			
		||||
				chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
				chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
				h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 追加聊天记录
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for reply
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			totalTokens := replyTokens + getTotalTokens(req)
 | 
			
		||||
			historyReplyMsg := model.ChatMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     totalTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
				Model:      req.Model,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 更新用户算力
 | 
			
		||||
			h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
			// 保存当前会话
 | 
			
		||||
			var chatItem model.ChatItem
 | 
			
		||||
			res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				chatItem.ChatId = session.ChatId
 | 
			
		||||
				chatItem.UserId = session.UserId
 | 
			
		||||
				chatItem.RoleId = role.Id
 | 
			
		||||
				chatItem.ModelId = session.Model.Id
 | 
			
		||||
				if utf8.RuneCountInString(prompt) > 30 {
 | 
			
		||||
					chatItem.Title = string([]rune(prompt)[:30]) + "..."
 | 
			
		||||
				} else {
 | 
			
		||||
					chatItem.Title = prompt
 | 
			
		||||
				}
 | 
			
		||||
				chatItem.Model = req.Model
 | 
			
		||||
				h.DB.Create(&chatItem)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var res struct {
 | 
			
		||||
			Code int    `json:"error_code"`
 | 
			
		||||
			Msg  string `json:"error_msg"`
 | 
			
		||||
		}
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		utils.ReplyMessage(ws, "请求通义千问大模型 API 失败:"+res.Msg)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
@@ -1,17 +1,26 @@
 | 
			
		||||
package chatimpl
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"crypto/hmac"
 | 
			
		||||
	"crypto/sha256"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"html/template"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
@@ -48,10 +57,17 @@ type xunFeiResp struct {
 | 
			
		||||
	} `json:"payload"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var Model2URL = map[string]string{
 | 
			
		||||
	"general":     "v1.1",
 | 
			
		||||
	"generalv2":   "v2.1",
 | 
			
		||||
	"generalv3":   "v3.1",
 | 
			
		||||
	"generalv3.5": "v3.5",
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 科大讯飞消息发送实现
 | 
			
		||||
 | 
			
		||||
func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
	chatCtx []interface{},
 | 
			
		||||
	chatCtx []types.Message,
 | 
			
		||||
	req types.ApiRequest,
 | 
			
		||||
	userVo vo.User,
 | 
			
		||||
	ctx context.Context,
 | 
			
		||||
@@ -60,35 +76,34 @@ func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
	prompt string,
 | 
			
		||||
	ws *types.WsClient) error {
 | 
			
		||||
	promptCreatedAt := time.Now() // 记录提问时间
 | 
			
		||||
	var apiKey = userVo.ChatConfig.ApiKeys[session.Model.Platform]
 | 
			
		||||
	if apiKey == "" {
 | 
			
		||||
		var key model.ApiKey
 | 
			
		||||
		res := h.db.Where("platform = ?", session.Model.Platform).Order("last_used_at ASC").First(&key)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
			return nil
 | 
			
		||||
		}
 | 
			
		||||
		// 更新 API KEY 的最后使用时间
 | 
			
		||||
		h.db.Model(&key).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
		apiKey = key.Value
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	// use the bind key
 | 
			
		||||
	if session.Model.KeyId > 0 {
 | 
			
		||||
		res = h.DB.Where("id", session.Model.KeyId).Where("enabled", true).Find(&apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	// use the last unused key
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
		res = h.DB.Where("platform", session.Model.Platform).Where("type", "chat").Where("enabled", true).Order("last_used_at ASC").First(&apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		utils.ReplyMessage(ws, "抱歉😔😔😔,系统已经没有可用的 API KEY,请联系管理员!")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
	// 更新 API KEY 的最后使用时间
 | 
			
		||||
	h.DB.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
 | 
			
		||||
	d := websocket.Dialer{
 | 
			
		||||
		HandshakeTimeout: 5 * time.Second,
 | 
			
		||||
	}
 | 
			
		||||
	key := strings.Split(apiKey, "|")
 | 
			
		||||
	key := strings.Split(apiKey.Value, "|")
 | 
			
		||||
	if len(key) != 3 {
 | 
			
		||||
		utils.ReplyMessage(ws, "非法的 API KEY!")
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var apiURL string
 | 
			
		||||
	if req.Model == "generalv2" {
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v2.1", 1)
 | 
			
		||||
	} else {
 | 
			
		||||
		apiURL = strings.Replace(h.App.ChatConfig.XunFei.ApiURL, "{version}", "v1.1", 1)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	apiURL := strings.Replace(apiKey.ApiURL, "{version}", Model2URL[req.Model], 1)
 | 
			
		||||
	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s, Model: %s", session.Model.Platform, apiURL, apiKey.Value, apiKey.ProxyURL, req.Model)
 | 
			
		||||
	wsURL, err := assembleAuthUrl(apiURL, key[1], key[2])
 | 
			
		||||
	//握手并建立websocket 连接
 | 
			
		||||
	conn, resp, err := d.Dial(wsURL, nil)
 | 
			
		||||
@@ -138,6 +153,10 @@ func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		content = result.Payload.Choices.Text[0].Content
 | 
			
		||||
		// 处理代码换行
 | 
			
		||||
		if len(content) == 0 {
 | 
			
		||||
			content = "\n"
 | 
			
		||||
		}
 | 
			
		||||
		contents = append(contents, content)
 | 
			
		||||
		// 第一个结果
 | 
			
		||||
		if result.Payload.Choices.Status == 0 {
 | 
			
		||||
@@ -165,9 +184,6 @@ func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
 | 
			
		||||
	// 消息发送成功
 | 
			
		||||
	if len(contents) > 0 {
 | 
			
		||||
		// 更新用户的对话次数
 | 
			
		||||
		h.subUserCalls(userVo, session)
 | 
			
		||||
 | 
			
		||||
		if message.Role == "" {
 | 
			
		||||
			message.Role = "assistant"
 | 
			
		||||
		}
 | 
			
		||||
@@ -175,63 +191,64 @@ func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
		useMsg := types.Message{Role: "user", Content: prompt}
 | 
			
		||||
 | 
			
		||||
		// 更新上下文消息,如果是调用函数则不需要更新上下文
 | 
			
		||||
		if h.App.ChatConfig.EnableContext {
 | 
			
		||||
		if h.App.SysConfig.EnableContext {
 | 
			
		||||
			chatCtx = append(chatCtx, useMsg)  // 提问消息
 | 
			
		||||
			chatCtx = append(chatCtx, message) // 回复消息
 | 
			
		||||
			h.App.ChatContexts.Put(session.ChatId, chatCtx)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 追加聊天记录
 | 
			
		||||
		if h.App.ChatConfig.EnableHistory {
 | 
			
		||||
			// for prompt
 | 
			
		||||
			promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg := model.HistoryMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.PromptMsg,
 | 
			
		||||
				Icon:       userVo.Avatar,
 | 
			
		||||
				Content:    prompt,
 | 
			
		||||
				Tokens:     promptToken,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
			}
 | 
			
		||||
			historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
			historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
			res := h.db.Save(&historyUserMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// for reply
 | 
			
		||||
			// 计算本次对话消耗的总 token 数量
 | 
			
		||||
			replyToken, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
			totalTokens := replyToken + getTotalTokens(req)
 | 
			
		||||
			historyReplyMsg := model.HistoryMessage{
 | 
			
		||||
				UserId:     userVo.Id,
 | 
			
		||||
				ChatId:     session.ChatId,
 | 
			
		||||
				RoleId:     role.Id,
 | 
			
		||||
				Type:       types.ReplyMsg,
 | 
			
		||||
				Icon:       role.Icon,
 | 
			
		||||
				Content:    message.Content,
 | 
			
		||||
				Tokens:     totalTokens,
 | 
			
		||||
				UseContext: true,
 | 
			
		||||
			}
 | 
			
		||||
			historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
			historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
			res = h.db.Create(&historyReplyMsg)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
			}
 | 
			
		||||
			// 更新用户信息
 | 
			
		||||
			h.incUserTokenFee(userVo.Id, totalTokens)
 | 
			
		||||
		// for prompt
 | 
			
		||||
		promptToken, err := utils.CalcTokens(prompt, req.Model)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error(err)
 | 
			
		||||
		}
 | 
			
		||||
		historyUserMsg := model.ChatMessage{
 | 
			
		||||
			UserId:     userVo.Id,
 | 
			
		||||
			ChatId:     session.ChatId,
 | 
			
		||||
			RoleId:     role.Id,
 | 
			
		||||
			Type:       types.PromptMsg,
 | 
			
		||||
			Icon:       userVo.Avatar,
 | 
			
		||||
			Content:    template.HTMLEscapeString(prompt),
 | 
			
		||||
			Tokens:     promptToken,
 | 
			
		||||
			UseContext: true,
 | 
			
		||||
			Model:      req.Model,
 | 
			
		||||
		}
 | 
			
		||||
		historyUserMsg.CreatedAt = promptCreatedAt
 | 
			
		||||
		historyUserMsg.UpdatedAt = promptCreatedAt
 | 
			
		||||
		res := h.DB.Save(&historyUserMsg)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("failed to save prompt history message: ", res.Error)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// for reply
 | 
			
		||||
		// 计算本次对话消耗的总 token 数量
 | 
			
		||||
		replyTokens, _ := utils.CalcTokens(message.Content, req.Model)
 | 
			
		||||
		totalTokens := replyTokens + getTotalTokens(req)
 | 
			
		||||
		historyReplyMsg := model.ChatMessage{
 | 
			
		||||
			UserId:     userVo.Id,
 | 
			
		||||
			ChatId:     session.ChatId,
 | 
			
		||||
			RoleId:     role.Id,
 | 
			
		||||
			Type:       types.ReplyMsg,
 | 
			
		||||
			Icon:       role.Icon,
 | 
			
		||||
			Content:    message.Content,
 | 
			
		||||
			Tokens:     totalTokens,
 | 
			
		||||
			UseContext: true,
 | 
			
		||||
			Model:      req.Model,
 | 
			
		||||
		}
 | 
			
		||||
		historyReplyMsg.CreatedAt = replyCreatedAt
 | 
			
		||||
		historyReplyMsg.UpdatedAt = replyCreatedAt
 | 
			
		||||
		res = h.DB.Create(&historyReplyMsg)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("failed to save reply history message: ", res.Error)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 更新用户算力
 | 
			
		||||
		h.subUserPower(userVo, session, promptToken, replyTokens)
 | 
			
		||||
 | 
			
		||||
		// 保存当前会话
 | 
			
		||||
		var chatItem model.ChatItem
 | 
			
		||||
		res := h.db.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
		res = h.DB.Where("chat_id = ?", session.ChatId).First(&chatItem)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			chatItem.ChatId = session.ChatId
 | 
			
		||||
			chatItem.UserId = session.UserId
 | 
			
		||||
@@ -242,7 +259,8 @@ func (h *ChatHandler) sendXunFeiMessage(
 | 
			
		||||
			} else {
 | 
			
		||||
				chatItem.Title = prompt
 | 
			
		||||
			}
 | 
			
		||||
			h.db.Create(&chatItem)
 | 
			
		||||
			chatItem.Model = req.Model
 | 
			
		||||
			h.DB.Create(&chatItem)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -258,7 +276,7 @@ func buildRequest(appid string, req types.ApiRequest) map[string]interface{} {
 | 
			
		||||
		"parameter": map[string]interface{}{
 | 
			
		||||
			"chat": map[string]interface{}{
 | 
			
		||||
				"domain":      req.Model,
 | 
			
		||||
				"temperature": float64(req.Temperature),
 | 
			
		||||
				"temperature": req.Temperature,
 | 
			
		||||
				"top_k":       int64(6),
 | 
			
		||||
				"max_tokens":  int64(req.MaxTokens),
 | 
			
		||||
				"auditing":    "default",
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										46
									
								
								api/handler/config_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										46
									
								
								api/handler/config_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,46 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ConfigHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewConfigHandler(app *core.AppServer, db *gorm.DB) *ConfigHandler {
 | 
			
		||||
	return &ConfigHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Get 获取指定的系统配置
 | 
			
		||||
func (h *ConfigHandler) Get(c *gin.Context) {
 | 
			
		||||
	key := c.Query("key")
 | 
			
		||||
	var config model.Config
 | 
			
		||||
	res := h.DB.Where("marker", key).First(&config)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var value map[string]interface{}
 | 
			
		||||
	err := utils.JsonDecode(config.Config, &value)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, value)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										261
									
								
								api/handler/dalle_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										261
									
								
								api/handler/dalle_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,261 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service/dalle"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"net/http"
 | 
			
		||||
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type DallJobHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	redis    *redis.Client
 | 
			
		||||
	service  *dalle.Service
 | 
			
		||||
	uploader *oss.UploaderManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewDallJobHandler(app *core.AppServer, db *gorm.DB, service *dalle.Service, manager *oss.UploaderManager) *DallJobHandler {
 | 
			
		||||
	return &DallJobHandler{
 | 
			
		||||
		service:  service,
 | 
			
		||||
		uploader: manager,
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: app,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
func (h *DallJobHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		logger.Info("Invalid user ID")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.service.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			_, msg, err := client.Receive()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				client.Close()
 | 
			
		||||
				h.service.Clients.Delete(uint(userId))
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var message types.WsMessage
 | 
			
		||||
			err = utils.JsonDecode(string(msg), &message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 心跳消息
 | 
			
		||||
			if message.Type == "heartbeat" {
 | 
			
		||||
				logger.Debug("收到 DallE 心跳消息:", message.Content)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *DallJobHandler) preCheck(c *gin.Context) bool {
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
	if user.Power < h.App.SysConfig.DallPower {
 | 
			
		||||
		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return true
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Image 创建一个绘画任务
 | 
			
		||||
func (h *DallJobHandler) Image(c *gin.Context) {
 | 
			
		||||
	if !h.preCheck(c) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var data types.DallTask
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil || data.Prompt == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
			
		||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
			
		||||
	job := model.DallJob{
 | 
			
		||||
		UserId: uint(userId),
 | 
			
		||||
		Prompt: data.Prompt,
 | 
			
		||||
		Power:  h.App.SysConfig.DallPower,
 | 
			
		||||
	}
 | 
			
		||||
	res := h.DB.Create(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "error with save job: "+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.service.PushTask(types.DallTask{
 | 
			
		||||
		JobId:   job.Id,
 | 
			
		||||
		UserId:  uint(userId),
 | 
			
		||||
		Prompt:  data.Prompt,
 | 
			
		||||
		Quality: data.Quality,
 | 
			
		||||
		Size:    data.Size,
 | 
			
		||||
		Style:   data.Style,
 | 
			
		||||
		Power:   job.Power,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	client := h.service.Clients.Get(job.UserId)
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ImgWall 照片墙
 | 
			
		||||
func (h *DallJobHandler) ImgWall(c *gin.Context) {
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	err, jobs := h.getData(true, 0, page, pageSize, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, jobs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 SD 任务列表
 | 
			
		||||
func (h *DallJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
	status := h.GetBool(c, "status")
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	publish := h.GetBool(c, "publish")
 | 
			
		||||
 | 
			
		||||
	err, jobs := h.getData(status, userId, page, pageSize, publish)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, jobs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取任务列表
 | 
			
		||||
func (h *DallJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.DallJob) {
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if finish {
 | 
			
		||||
		session = session.Where("progress = ?", 100).Order("id DESC")
 | 
			
		||||
	} else {
 | 
			
		||||
		session = session.Where("progress < ?", 100).Order("id ASC")
 | 
			
		||||
	}
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		session = session.Where("user_id = ?", userId)
 | 
			
		||||
	}
 | 
			
		||||
	if publish {
 | 
			
		||||
		session = session.Where("publish", publish)
 | 
			
		||||
	}
 | 
			
		||||
	if page > 0 && pageSize > 0 {
 | 
			
		||||
		offset := (page - 1) * pageSize
 | 
			
		||||
		session = session.Offset(offset).Limit(pageSize)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var items []model.DallJob
 | 
			
		||||
	res := session.Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return res.Error, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var jobs = make([]vo.DallJob, 0)
 | 
			
		||||
	for _, item := range items {
 | 
			
		||||
		var job vo.DallJob
 | 
			
		||||
		err := utils.CopyObject(item, &job)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		jobs = append(jobs, job)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil, jobs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove remove task image
 | 
			
		||||
func (h *DallJobHandler) Remove(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint   `json:"id"`
 | 
			
		||||
		UserId uint   `json:"user_id"`
 | 
			
		||||
		ImgURL string `json:"img_url"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove job recode
 | 
			
		||||
	res := h.DB.Delete(&model.DallJob{Id: data.Id})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove image
 | 
			
		||||
	err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("remove image failed: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Publish 发布/取消发布图片到画廊显示
 | 
			
		||||
func (h *DallJobHandler) Publish(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint `json:"id"`
 | 
			
		||||
		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.DallJob{Id: data.Id}).UpdateColumn("publish", true)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										226
									
								
								api/handler/function_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										226
									
								
								api/handler/function_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,226 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service/dalle"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type FunctionHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	config        types.ApiConfig
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	dallService   *dalle.Service
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewFunctionHandler(
 | 
			
		||||
	server *core.AppServer,
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	config *types.AppConfig,
 | 
			
		||||
	manager *oss.UploaderManager,
 | 
			
		||||
	dallService *dalle.Service) *FunctionHandler {
 | 
			
		||||
	return &FunctionHandler{
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: server,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
		},
 | 
			
		||||
		config:        config.ApiConfig,
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
		dallService:   dallService,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type resVo struct {
 | 
			
		||||
	Code    types.BizCode `json:"code"`
 | 
			
		||||
	Message string        `json:"message"`
 | 
			
		||||
	Data    struct {
 | 
			
		||||
		Title     string     `json:"title"`
 | 
			
		||||
		UpdatedAt string     `json:"updated_at"`
 | 
			
		||||
		Items     []dataItem `json:"items"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type dataItem struct {
 | 
			
		||||
	Title  string `json:"title"`
 | 
			
		||||
	Url    string `json:"url"`
 | 
			
		||||
	Remark string `json:"remark"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// check authorization
 | 
			
		||||
func (h *FunctionHandler) checkAuth(c *gin.Context) error {
 | 
			
		||||
	tokenString := c.GetHeader(types.UserAuthHeader)
 | 
			
		||||
	token, err := jwt.Parse(tokenString, func(token *jwt.Token) (interface{}, error) {
 | 
			
		||||
		if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
 | 
			
		||||
			return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		return []byte(h.App.Config.Session.SecretKey), nil
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with parse auth token: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	claims, ok := token.Claims.(jwt.MapClaims)
 | 
			
		||||
	if !ok || !token.Valid {
 | 
			
		||||
		return errors.New("token is invalid")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	expr := utils.IntValue(utils.InterfaceToString(claims["expired"]), 0)
 | 
			
		||||
	if expr > 0 && int64(expr) < time.Now().Unix() {
 | 
			
		||||
		return errors.New("token is expired")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// WeiBo 微博热搜
 | 
			
		||||
func (h *FunctionHandler) WeiBo(c *gin.Context) {
 | 
			
		||||
	if err := h.checkAuth(c); err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if h.config.Token == "" {
 | 
			
		||||
		resp.ERROR(c, "无效的 API Token")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/weibo/fetch", h.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("AppId", h.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		resp.ERROR(c, res.Message)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
 | 
			
		||||
	for i, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, strings.Join(builder, "\n\n"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// ZaoBao 今日早报
 | 
			
		||||
func (h *FunctionHandler) ZaoBao(c *gin.Context) {
 | 
			
		||||
	if err := h.checkAuth(c); err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if h.config.Token == "" {
 | 
			
		||||
		resp.ERROR(c, "无效的 API Token")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/zaobao/fetch", h.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("AppId", h.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", h.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		resp.ERROR(c, fmt.Sprintf("%v%v", err, r.Err))
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		resp.ERROR(c, res.Message)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
 | 
			
		||||
	for _, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, v.Title)
 | 
			
		||||
	}
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
 | 
			
		||||
	resp.SUCCESS(c, strings.Join(builder, "\n\n"))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Dall3 DallE3 AI 绘图
 | 
			
		||||
func (h *FunctionHandler) Dall3(c *gin.Context) {
 | 
			
		||||
	if err := h.checkAuth(c); err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var params map[string]interface{}
 | 
			
		||||
	if err := c.ShouldBindJSON(¶ms); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("绘画参数:%+v", params)
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.DB.Where("id = ?", params["user_id"]).First(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "当前用户不存在!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Power < h.App.SysConfig.DallPower {
 | 
			
		||||
		resp.ERROR(c, "创建 DALL-E 绘图任务失败,算力不足")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// create dall task
 | 
			
		||||
	prompt := utils.InterfaceToString(params["prompt"])
 | 
			
		||||
	job := model.DallJob{
 | 
			
		||||
		UserId: user.Id,
 | 
			
		||||
		Prompt: prompt,
 | 
			
		||||
		Power:  h.App.SysConfig.DallPower,
 | 
			
		||||
	}
 | 
			
		||||
	res = h.DB.Create(&job)
 | 
			
		||||
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "创建 DALL-E 绘图任务失败:"+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	content, err := h.dallService.Image(types.DallTask{
 | 
			
		||||
		JobId:   job.Id,
 | 
			
		||||
		UserId:  user.Id,
 | 
			
		||||
		Prompt:  job.Prompt,
 | 
			
		||||
		N:       1,
 | 
			
		||||
		Quality: "standard",
 | 
			
		||||
		Size:    "1024x1024",
 | 
			
		||||
		Style:   "vivid",
 | 
			
		||||
		Power:   job.Power,
 | 
			
		||||
	}, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "任务执行失败:"+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, content)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										100
									
								
								api/handler/invite_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										100
									
								
								api/handler/invite_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,100 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// InviteHandler 用户邀请
 | 
			
		||||
type InviteHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewInviteHandler(app *core.AppServer, db *gorm.DB) *InviteHandler {
 | 
			
		||||
	return &InviteHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Code 获取当前用户邀请码
 | 
			
		||||
func (h *InviteHandler) Code(c *gin.Context) {
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var inviteCode model.InviteCode
 | 
			
		||||
	res := h.DB.Where("user_id = ?", userId).First(&inviteCode)
 | 
			
		||||
	// 如果邀请码不存在,则创建一个
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		code := strings.ToUpper(utils.RandString(8))
 | 
			
		||||
		for {
 | 
			
		||||
			res = h.DB.Where("code = ?", code).First(&inviteCode)
 | 
			
		||||
			if res.Error != nil { // 不存在相同的邀请码则退出
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		inviteCode.UserId = userId
 | 
			
		||||
		inviteCode.Code = code
 | 
			
		||||
		h.DB.Create(&inviteCode)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var codeVo vo.InviteCode
 | 
			
		||||
	err := utils.CopyObject(inviteCode, &codeVo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "拷贝对象失败")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, codeVo)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List Log 用户邀请记录
 | 
			
		||||
func (h *InviteHandler) List(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Page     int `json:"page"`
 | 
			
		||||
		PageSize int `json:"page_size"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{}).Where("inviter_id = ?", userId)
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.InviteLog{}).Count(&total)
 | 
			
		||||
	var items []model.InviteLog
 | 
			
		||||
	var list = make([]vo.InviteLog, 0)
 | 
			
		||||
	offset := (data.Page - 1) * data.PageSize
 | 
			
		||||
	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var v vo.InviteLog
 | 
			
		||||
			err := utils.CopyObject(item, &v)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				v.Id = item.Id
 | 
			
		||||
				v.CreatedAt = item.CreatedAt.Unix()
 | 
			
		||||
				list = append(list, v)
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Hits 访问邀请码
 | 
			
		||||
func (h *InviteHandler) Hits(c *gin.Context) {
 | 
			
		||||
	code := c.Query("code")
 | 
			
		||||
	h.DB.Model(&model.InviteCode{}).Where("code = ?", code).UpdateColumn("hits", gorm.Expr("hits + ?", 1))
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										273
									
								
								api/handler/markmap_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										273
									
								
								api/handler/markmap_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,273 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bufio"
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"encoding/json"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MarkMapHandler 生成思维导图
 | 
			
		||||
type MarkMapHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	clients *types.LMap[int, *types.WsClient]
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMarkMapHandler(app *core.AppServer, db *gorm.DB) *MarkMapHandler {
 | 
			
		||||
	return &MarkMapHandler{
 | 
			
		||||
		BaseHandler: BaseHandler{App: app, DB: db},
 | 
			
		||||
		clients:     types.NewLMap[int, *types.WsClient](),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MarkMapHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	modelId := h.GetInt(c, "model_id", 0)
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.clients.Put(userId, client)
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			_, msg, err := client.Receive()
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				client.Close()
 | 
			
		||||
				h.clients.Delete(userId)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var message types.WsMessage
 | 
			
		||||
			err = utils.JsonDecode(string(msg), &message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 心跳消息
 | 
			
		||||
			if message.Type == "heartbeat" {
 | 
			
		||||
				logger.Debug("收到 MarkMap 心跳消息:", message.Content)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			// change model
 | 
			
		||||
			if message.Type == "model_id" {
 | 
			
		||||
				modelId = utils.IntValue(utils.InterfaceToString(message.Content), 0)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			logger.Info("Receive a message: ", message.Content)
 | 
			
		||||
			err = h.sendMessage(client, utils.InterfaceToString(message.Content), modelId, userId)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsErr, Content: err.Error()})
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MarkMapHandler) sendMessage(client *types.WsClient, prompt string, modelId int, userId int) error {
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.DB.Model(&model.User{}).First(&user, userId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return fmt.Errorf("error with query user info: %v", res.Error)
 | 
			
		||||
	}
 | 
			
		||||
	var chatModel model.ChatModel
 | 
			
		||||
	res = h.DB.Where("id", modelId).First(&chatModel)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return fmt.Errorf("error with query chat model: %v", res.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Status == false {
 | 
			
		||||
		return errors.New("当前用户被禁用")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Power < chatModel.Power {
 | 
			
		||||
		return fmt.Errorf("您当前剩余算力(%d)已不足以支付当前模型算力(%d)!", user.Power, chatModel.Power)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	messages := make([]interface{}, 0)
 | 
			
		||||
	messages = append(messages, types.Message{Role: "system", Content: `
 | 
			
		||||
你是一位非常优秀的思维导图助手,你会把用户的所有提问都总结成思维导图,然后以 Markdown 格式输出。markdown 只需要输出一级标题,二级标题,三级标题,四级标题,最多输出四级,除此之外不要输出任何其他 markdown 标记。下面是一个合格的例子:
 | 
			
		||||
# Geek-AI 助手
 | 
			
		||||
 | 
			
		||||
## 完整的开源系统
 | 
			
		||||
### 前端开源
 | 
			
		||||
### 后端开源
 | 
			
		||||
 | 
			
		||||
## 支持各种大模型
 | 
			
		||||
### OpenAI 
 | 
			
		||||
### Azure 
 | 
			
		||||
### 文心一言
 | 
			
		||||
### 通义千问
 | 
			
		||||
 | 
			
		||||
## 集成多种收费方式
 | 
			
		||||
### 支付宝
 | 
			
		||||
### 微信
 | 
			
		||||
 | 
			
		||||
另外,除此之外不要任何解释性语句。
 | 
			
		||||
`})
 | 
			
		||||
	messages = append(messages, types.Message{Role: "user", Content: prompt})
 | 
			
		||||
	var req = types.ApiRequest{
 | 
			
		||||
		Model:    chatModel.Value,
 | 
			
		||||
		Stream:   true,
 | 
			
		||||
		Messages: messages,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	response, err := h.doRequest(req, chatModel, &apiKey)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("请求 OpenAI API 失败: %s", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer response.Body.Close()
 | 
			
		||||
 | 
			
		||||
	contentType := response.Header.Get("Content-Type")
 | 
			
		||||
	if strings.Contains(contentType, "text/event-stream") {
 | 
			
		||||
		// 循环读取 Chunk 消息
 | 
			
		||||
		scanner := bufio.NewScanner(response.Body)
 | 
			
		||||
		var isNew = true
 | 
			
		||||
		for scanner.Scan() {
 | 
			
		||||
			line := scanner.Text()
 | 
			
		||||
			if !strings.Contains(line, "data:") || len(line) < 30 {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			var responseBody = types.ApiResponse{}
 | 
			
		||||
			err = json.Unmarshal([]byte(line[6:]), &responseBody)
 | 
			
		||||
			if err != nil { // 数据解析出错
 | 
			
		||||
				return fmt.Errorf("error with decode data: %v", line)
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if len(responseBody.Choices) == 0 { // Fixed: 兼容 Azure API 第一个输出空行
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if responseBody.Choices[0].FinishReason == "stop" {
 | 
			
		||||
				break
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			if isNew {
 | 
			
		||||
				utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsStart})
 | 
			
		||||
				isNew = false
 | 
			
		||||
			}
 | 
			
		||||
			utils.ReplyChunkMessage(client, types.WsMessage{
 | 
			
		||||
				Type:    types.WsMiddle,
 | 
			
		||||
				Content: utils.InterfaceToString(responseBody.Choices[0].Delta.Content),
 | 
			
		||||
			})
 | 
			
		||||
		} // end for
 | 
			
		||||
 | 
			
		||||
		utils.ReplyChunkMessage(client, types.WsMessage{Type: types.WsEnd})
 | 
			
		||||
 | 
			
		||||
	} else {
 | 
			
		||||
		body, err := io.ReadAll(response.Body)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("读取响应失败: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		var res types.ApiError
 | 
			
		||||
		err = json.Unmarshal(body, &res)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return fmt.Errorf("解析响应失败: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// OpenAI API 调用异常处理
 | 
			
		||||
		if strings.Contains(res.Error.Message, "This key is associated with a deactivated account") {
 | 
			
		||||
			// remove key
 | 
			
		||||
			h.DB.Where("value = ?", apiKey).Delete(&model.ApiKey{})
 | 
			
		||||
			return errors.New("请求 OpenAI API 失败:API KEY 所关联的账户被禁用。")
 | 
			
		||||
		} else if strings.Contains(res.Error.Message, "You exceeded your current quota") {
 | 
			
		||||
			return errors.New("请求 OpenAI API 失败:API KEY 触发并发限制,请稍后再试。")
 | 
			
		||||
		} else {
 | 
			
		||||
			return fmt.Errorf("请求 OpenAI API 失败:%v", res.Error.Message)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 扣减算力
 | 
			
		||||
	res = h.DB.Model(&model.User{}).Where("id", userId).UpdateColumn("power", gorm.Expr("power - ?", chatModel.Power))
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		// 记录算力消费日志
 | 
			
		||||
		var u model.User
 | 
			
		||||
		h.DB.Where("id", userId).First(&u)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    u.Id,
 | 
			
		||||
			Username:  u.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    chatModel.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Balance:   u.Power,
 | 
			
		||||
			Model:     chatModel.Value,
 | 
			
		||||
			Remark:    fmt.Sprintf("AI绘制思维导图,模型名称:%s, ", chatModel.Value),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MarkMapHandler) doRequest(req types.ApiRequest, chatModel model.ChatModel, apiKey *model.ApiKey) (*http.Response, error) {
 | 
			
		||||
	// if the chat model bind a KEY, use it directly
 | 
			
		||||
	var res *gorm.DB
 | 
			
		||||
	if chatModel.KeyId > 0 {
 | 
			
		||||
		res = h.DB.Where("id", chatModel.KeyId).Where("enabled", true).Find(apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	// use the last unused key
 | 
			
		||||
	if apiKey.Id == 0 {
 | 
			
		||||
		res = h.DB.Where("platform", types.OpenAI).
 | 
			
		||||
			Where("type", "chat").
 | 
			
		||||
			Where("enabled", true).Order("last_used_at ASC").First(apiKey)
 | 
			
		||||
	}
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return nil, errors.New("no available key, please import key")
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := apiKey.ApiURL
 | 
			
		||||
	// 更新 API KEY 的最后使用时间
 | 
			
		||||
	h.DB.Model(apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
 | 
			
		||||
	// 创建 HttpClient 请求对象
 | 
			
		||||
	var client *http.Client
 | 
			
		||||
	requestBody, err := json.Marshal(req)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
	request, err := http.NewRequest(http.MethodPost, apiURL, bytes.NewBuffer(requestBody))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	request.Header.Set("Content-Type", "application/json")
 | 
			
		||||
	if len(apiKey.ProxyURL) > 5 { // 使用代理
 | 
			
		||||
		proxy, _ := url.Parse(apiKey.ProxyURL)
 | 
			
		||||
		client = &http.Client{
 | 
			
		||||
			Transport: &http.Transport{
 | 
			
		||||
				Proxy: http.ProxyURL(proxy),
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
	} else {
 | 
			
		||||
		client = http.DefaultClient
 | 
			
		||||
	}
 | 
			
		||||
	request.Header.Set("Authorization", fmt.Sprintf("Bearer %s", apiKey.Value))
 | 
			
		||||
	return client.Do(request)
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										43
									
								
								api/handler/menu_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										43
									
								
								api/handler/menu_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,43 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MenuHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMenuHandler(app *core.AppServer, db *gorm.DB) *MenuHandler {
 | 
			
		||||
	return &MenuHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List 数据列表
 | 
			
		||||
func (h *MenuHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.Menu
 | 
			
		||||
	var list = make([]vo.Menu, 0)
 | 
			
		||||
	res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var product vo.Menu
 | 
			
		||||
			err := utils.CopyObject(item, &product)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				list = append(list, product)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, list)
 | 
			
		||||
}
 | 
			
		||||
@@ -1,68 +1,66 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/mj"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MidJourneyHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	redis     *redis.Client
 | 
			
		||||
	db        *gorm.DB
 | 
			
		||||
	mjService *mj.Service
 | 
			
		||||
	pool      *mj.ServicePool
 | 
			
		||||
	snowflake *service.Snowflake
 | 
			
		||||
	uploader  *oss.UploaderManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMidJourneyHandler(
 | 
			
		||||
	app *core.AppServer,
 | 
			
		||||
	client *redis.Client,
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	mjService *mj.Service) *MidJourneyHandler {
 | 
			
		||||
	h := MidJourneyHandler{
 | 
			
		||||
		redis:     client,
 | 
			
		||||
		db:        db,
 | 
			
		||||
		mjService: mjService,
 | 
			
		||||
func NewMidJourneyHandler(app *core.AppServer, db *gorm.DB, snowflake *service.Snowflake, pool *mj.ServicePool, manager *oss.UploaderManager) *MidJourneyHandler {
 | 
			
		||||
	return &MidJourneyHandler{
 | 
			
		||||
		snowflake: snowflake,
 | 
			
		||||
		pool:      pool,
 | 
			
		||||
		uploader:  manager,
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: app,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sessionId := c.Query("session_id")
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.mjService.Clients.Put(sessionId, client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.ClientIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
func (h *MidJourneyHandler) preCheck(c *gin.Context) bool {
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.ImgCalls <= 0 {
 | 
			
		||||
		resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
 | 
			
		||||
	if user.Power < h.App.SysConfig.MjPower {
 | 
			
		||||
		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !h.pool.HasAvailableService() {
 | 
			
		||||
		resp.ERROR(c, "MidJourney 池子中没有没有可用的服务!")
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -70,97 +68,180 @@ func (h *MidJourneyHandler) checkLimits(c *gin.Context) bool {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Image 创建一个绘画任务
 | 
			
		||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
			
		||||
	if !h.App.Config.MjConfig.Enabled {
 | 
			
		||||
		resp.ERROR(c, "MidJourney service is disabled")
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
func (h *MidJourneyHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		logger.Info("Invalid user ID")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	h.pool.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Image 创建一个绘画任务
 | 
			
		||||
func (h *MidJourneyHandler) Image(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		SessionId string  `json:"session_id"`
 | 
			
		||||
		Prompt    string  `json:"prompt"`
 | 
			
		||||
		Rate      string  `json:"rate"`
 | 
			
		||||
		Model     string  `json:"model"`
 | 
			
		||||
		Chaos     int     `json:"chaos"`
 | 
			
		||||
		Raw       bool    `json:"raw"`
 | 
			
		||||
		Seed      int64   `json:"seed"`
 | 
			
		||||
		Stylize   int     `json:"stylize"`
 | 
			
		||||
		Img       string  `json:"img"`
 | 
			
		||||
		Weight    float32 `json:"weight"`
 | 
			
		||||
		SessionId string   `json:"session_id"`
 | 
			
		||||
		TaskType  string   `json:"task_type"`
 | 
			
		||||
		Prompt    string   `json:"prompt"`
 | 
			
		||||
		NegPrompt string   `json:"neg_prompt"`
 | 
			
		||||
		Rate      string   `json:"rate"`
 | 
			
		||||
		Model     string   `json:"model"`
 | 
			
		||||
		Chaos     int      `json:"chaos"`
 | 
			
		||||
		Raw       bool     `json:"raw"`
 | 
			
		||||
		Seed      int64    `json:"seed"`
 | 
			
		||||
		Stylize   int      `json:"stylize"`
 | 
			
		||||
		ImgArr    []string `json:"img_arr"`
 | 
			
		||||
		Tile      bool     `json:"tile"`
 | 
			
		||||
		Quality   float32  `json:"quality"`
 | 
			
		||||
		Iw        float32  `json:"iw"`
 | 
			
		||||
		CRef      string   `json:"cref"` //生成角色一致的图像
 | 
			
		||||
		SRef      string   `json:"sref"` //生成风格一致的图像
 | 
			
		||||
		Cw        int      `json:"cw"`   // 参考程度
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if !h.checkLimits(c) {
 | 
			
		||||
	if !h.preCheck(c) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var prompt = data.Prompt
 | 
			
		||||
	if data.Rate != "" && !strings.Contains(prompt, "--ar") {
 | 
			
		||||
		prompt += " --ar " + data.Rate
 | 
			
		||||
	var params = ""
 | 
			
		||||
	if data.Rate != "" && !strings.Contains(params, "--ar") {
 | 
			
		||||
		params += " --ar " + data.Rate
 | 
			
		||||
	}
 | 
			
		||||
	if data.Seed > 0 && !strings.Contains(prompt, "--seed") {
 | 
			
		||||
		prompt += fmt.Sprintf(" --seed %d", data.Seed)
 | 
			
		||||
	if data.Seed > 0 && !strings.Contains(params, "--seed") {
 | 
			
		||||
		params += fmt.Sprintf(" --seed %d", data.Seed)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Stylize > 0 && !strings.Contains(prompt, "--s") && !strings.Contains(prompt, "--stylize") {
 | 
			
		||||
		prompt += fmt.Sprintf(" --s %d", data.Stylize)
 | 
			
		||||
	if data.Stylize > 0 && !strings.Contains(params, "--s") && !strings.Contains(params, "--stylize") {
 | 
			
		||||
		params += fmt.Sprintf(" --s %d", data.Stylize)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Chaos > 0 && !strings.Contains(prompt, "--c") && !strings.Contains(prompt, "--chaos") {
 | 
			
		||||
		prompt += fmt.Sprintf(" --c %d", data.Chaos)
 | 
			
		||||
	if data.Chaos > 0 && !strings.Contains(params, "--c") && !strings.Contains(params, "--chaos") {
 | 
			
		||||
		params += fmt.Sprintf(" --c %d", data.Chaos)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Img != "" {
 | 
			
		||||
		prompt = fmt.Sprintf("%s %s", data.Img, prompt)
 | 
			
		||||
		if data.Weight > 0 {
 | 
			
		||||
			prompt += fmt.Sprintf(" --iw %f", data.Weight)
 | 
			
		||||
		}
 | 
			
		||||
	if len(data.ImgArr) > 0 && data.Iw > 0 {
 | 
			
		||||
		params += fmt.Sprintf(" --iw %.2f", data.Iw)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Raw {
 | 
			
		||||
		prompt += " --style raw"
 | 
			
		||||
		params += " --style raw"
 | 
			
		||||
	}
 | 
			
		||||
	if data.Model != "" && !strings.Contains(prompt, "--v") && !strings.Contains(prompt, "--niji") {
 | 
			
		||||
		prompt += data.Model
 | 
			
		||||
	if data.Quality > 0 {
 | 
			
		||||
		params += fmt.Sprintf(" --q %.2f", data.Quality)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Tile {
 | 
			
		||||
		params += " --tile "
 | 
			
		||||
	}
 | 
			
		||||
	if data.CRef != "" {
 | 
			
		||||
		params += fmt.Sprintf(" --cref %s", data.CRef)
 | 
			
		||||
		if data.Cw > 0 {
 | 
			
		||||
			params += fmt.Sprintf(" --cw %d", data.Cw)
 | 
			
		||||
		} else {
 | 
			
		||||
			params += " --cw 100"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if data.SRef != "" {
 | 
			
		||||
		params += fmt.Sprintf(" --sref %s", data.SRef)
 | 
			
		||||
	}
 | 
			
		||||
	if data.Model != "" && !strings.Contains(params, "--v") && !strings.Contains(params, "--niji") {
 | 
			
		||||
		params += fmt.Sprintf(" %s", data.Model)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 处理融图和换脸的提示词
 | 
			
		||||
	if data.TaskType == types.TaskSwapFace.String() || data.TaskType == types.TaskBlend.String() {
 | 
			
		||||
		params = fmt.Sprintf("%s:%s", data.TaskType, strings.Join(data.ImgArr, ","))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 如果本地图片上传的是相对地址,处理成绝对地址
 | 
			
		||||
	for k, v := range data.ImgArr {
 | 
			
		||||
		if !strings.HasPrefix(v, "http") {
 | 
			
		||||
			data.ImgArr[k] = fmt.Sprintf("http://localhost:5678/%s", strings.TrimLeft(v, "/"))
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
			
		||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
			
		||||
	// generate task id
 | 
			
		||||
	taskId, err := h.snowflake.Next(true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "error with generate task id: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	job := model.MidJourneyJob{
 | 
			
		||||
		Type:      types.TaskImage.String(),
 | 
			
		||||
		Type:      data.TaskType,
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
		TaskId:    taskId,
 | 
			
		||||
		Progress:  0,
 | 
			
		||||
		Prompt:    prompt,
 | 
			
		||||
		Prompt:    fmt.Sprintf("%s %s", data.Prompt, params),
 | 
			
		||||
		Power:     h.App.SysConfig.MjPower,
 | 
			
		||||
		CreatedAt: time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
	if res := h.db.Create(&job); res.Error != nil {
 | 
			
		||||
	opt := "绘图"
 | 
			
		||||
	if data.TaskType == types.TaskBlend.String() {
 | 
			
		||||
		job.Prompt = "融图:" + strings.Join(data.ImgArr, ",")
 | 
			
		||||
		opt = "融图"
 | 
			
		||||
	} else if data.TaskType == types.TaskSwapFace.String() {
 | 
			
		||||
		job.Prompt = "换脸:" + strings.Join(data.ImgArr, ",")
 | 
			
		||||
		opt = "换脸"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
 | 
			
		||||
		resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.mjService.PushTask(types.MjTask{
 | 
			
		||||
		Id:        int(job.Id),
 | 
			
		||||
	h.pool.PushTask(types.MjTask{
 | 
			
		||||
		Id:        job.Id,
 | 
			
		||||
		TaskId:    taskId,
 | 
			
		||||
		SessionId: data.SessionId,
 | 
			
		||||
		Src:       types.TaskSrcImg,
 | 
			
		||||
		Type:      types.TaskImage,
 | 
			
		||||
		Prompt:    prompt,
 | 
			
		||||
		Type:      types.TaskType(data.TaskType),
 | 
			
		||||
		Prompt:    data.Prompt,
 | 
			
		||||
		NegPrompt: data.NegPrompt,
 | 
			
		||||
		Params:    params,
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
		ImgArr:    data.ImgArr,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	var jobVo vo.MidJourneyJob
 | 
			
		||||
	err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		// 推送任务到前端
 | 
			
		||||
		client := h.mjService.Clients.Get(data.SessionId)
 | 
			
		||||
		if client != nil {
 | 
			
		||||
			utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
		}
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power - job.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "mid-journey",
 | 
			
		||||
			Remark:    fmt.Sprintf("%s操作,任务ID:%s", opt, job.TaskId),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type reqVo struct {
 | 
			
		||||
	Src         string `json:"src"`
 | 
			
		||||
	Index       int    `json:"index"`
 | 
			
		||||
	ChannelId   string `json:"channel_id"`
 | 
			
		||||
	MessageId   string `json:"message_id"`
 | 
			
		||||
	MessageHash string `json:"message_hash"`
 | 
			
		||||
	SessionId   string `json:"session_id"`
 | 
			
		||||
@@ -178,64 +259,60 @@ func (h *MidJourneyHandler) Upscale(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !h.checkLimits(c) {
 | 
			
		||||
	if !h.preCheck(c) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
			
		||||
	jobId := 0
 | 
			
		||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
			
		||||
	src := types.TaskSrc(data.Src)
 | 
			
		||||
	if src == types.TaskSrcImg {
 | 
			
		||||
		job := model.MidJourneyJob{
 | 
			
		||||
			Type:      types.TaskUpscale.String(),
 | 
			
		||||
			UserId:    userId,
 | 
			
		||||
			Hash:      data.MessageHash,
 | 
			
		||||
			Progress:  0,
 | 
			
		||||
			Prompt:    data.Prompt,
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		}
 | 
			
		||||
		if res := h.db.Create(&job); res.Error == nil {
 | 
			
		||||
			jobId = int(job.Id)
 | 
			
		||||
		} else {
 | 
			
		||||
			resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var jobVo vo.MidJourneyJob
 | 
			
		||||
		err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			// 推送任务到前端
 | 
			
		||||
			client := h.mjService.Clients.Get(data.SessionId)
 | 
			
		||||
			if client != nil {
 | 
			
		||||
				utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	taskId, _ := h.snowflake.Next(true)
 | 
			
		||||
	job := model.MidJourneyJob{
 | 
			
		||||
		Type:        types.TaskUpscale.String(),
 | 
			
		||||
		ReferenceId: data.MessageId,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		TaskId:      taskId,
 | 
			
		||||
		Progress:    0,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		Power:       h.App.SysConfig.MjActionPower,
 | 
			
		||||
		CreatedAt:   time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
	h.mjService.PushTask(types.MjTask{
 | 
			
		||||
		Id:          jobId,
 | 
			
		||||
	if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
 | 
			
		||||
		resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.pool.PushTask(types.MjTask{
 | 
			
		||||
		Id:          job.Id,
 | 
			
		||||
		SessionId:   data.SessionId,
 | 
			
		||||
		Src:         src,
 | 
			
		||||
		Type:        types.TaskUpscale,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		RoleId:      data.RoleId,
 | 
			
		||||
		Icon:        data.Icon,
 | 
			
		||||
		ChatId:      data.ChatId,
 | 
			
		||||
		ChannelId:   data.ChannelId,
 | 
			
		||||
		Index:       data.Index,
 | 
			
		||||
		MessageId:   data.MessageId,
 | 
			
		||||
		MessageHash: data.MessageHash,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if src == types.TaskSrcChat {
 | 
			
		||||
		wsClient := h.App.ChatClients.Get(data.SessionId)
 | 
			
		||||
		if wsClient != nil {
 | 
			
		||||
			content := fmt.Sprintf("**%s** 已推送 upscale 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
			
		||||
			utils.ReplyMessage(wsClient, content)
 | 
			
		||||
			if h.mjService.ChatClients.Get(data.SessionId) == nil {
 | 
			
		||||
				h.mjService.ChatClients.Put(data.SessionId, wsClient)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
	// update user's power
 | 
			
		||||
	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power - job.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "mid-journey",
 | 
			
		||||
			Remark:    fmt.Sprintf("Upscale 操作,任务ID:%s", job.TaskId),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
@@ -248,79 +325,100 @@ func (h *MidJourneyHandler) Variation(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !h.checkLimits(c) {
 | 
			
		||||
	if !h.preCheck(c) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
			
		||||
	jobId := 0
 | 
			
		||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
			
		||||
	src := types.TaskSrc(data.Src)
 | 
			
		||||
	if src == types.TaskSrcImg {
 | 
			
		||||
		job := model.MidJourneyJob{
 | 
			
		||||
			Type:      types.TaskVariation.String(),
 | 
			
		||||
			UserId:    userId,
 | 
			
		||||
			ImgURL:    "",
 | 
			
		||||
			Hash:      data.MessageHash,
 | 
			
		||||
			Progress:  0,
 | 
			
		||||
			Prompt:    data.Prompt,
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		}
 | 
			
		||||
		if res := h.db.Create(&job); res.Error == nil {
 | 
			
		||||
			jobId = int(job.Id)
 | 
			
		||||
		} else {
 | 
			
		||||
			resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var jobVo vo.MidJourneyJob
 | 
			
		||||
		err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			// 推送任务到前端
 | 
			
		||||
			client := h.mjService.Clients.Get(data.SessionId)
 | 
			
		||||
			if client != nil {
 | 
			
		||||
				utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	taskId, _ := h.snowflake.Next(true)
 | 
			
		||||
	job := model.MidJourneyJob{
 | 
			
		||||
		Type:        types.TaskVariation.String(),
 | 
			
		||||
		ChannelId:   data.ChannelId,
 | 
			
		||||
		ReferenceId: data.MessageId,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		TaskId:      taskId,
 | 
			
		||||
		Progress:    0,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		Power:       h.App.SysConfig.MjActionPower,
 | 
			
		||||
		CreatedAt:   time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
	h.mjService.PushTask(types.MjTask{
 | 
			
		||||
		Id:          jobId,
 | 
			
		||||
	if res := h.DB.Create(&job); res.Error != nil || res.RowsAffected == 0 {
 | 
			
		||||
		resp.ERROR(c, "添加任务失败:"+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.pool.PushTask(types.MjTask{
 | 
			
		||||
		Id:          job.Id,
 | 
			
		||||
		SessionId:   data.SessionId,
 | 
			
		||||
		Src:         src,
 | 
			
		||||
		Type:        types.TaskVariation,
 | 
			
		||||
		Prompt:      data.Prompt,
 | 
			
		||||
		UserId:      userId,
 | 
			
		||||
		RoleId:      data.RoleId,
 | 
			
		||||
		Icon:        data.Icon,
 | 
			
		||||
		ChatId:      data.ChatId,
 | 
			
		||||
		Index:       data.Index,
 | 
			
		||||
		ChannelId:   data.ChannelId,
 | 
			
		||||
		MessageId:   data.MessageId,
 | 
			
		||||
		MessageHash: data.MessageHash,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	if src == types.TaskSrcChat {
 | 
			
		||||
		// 从聊天窗口发送的请求,记录客户端信息
 | 
			
		||||
		wsClient := h.mjService.ChatClients.Get(data.SessionId)
 | 
			
		||||
		if wsClient != nil {
 | 
			
		||||
			content := fmt.Sprintf("**%s** 已推送 variation 任务到 MidJourney 机器人,请耐心等待任务执行...", data.Prompt)
 | 
			
		||||
			utils.ReplyMessage(wsClient, content)
 | 
			
		||||
			if h.mjService.Clients.Get(data.SessionId) == nil {
 | 
			
		||||
				h.mjService.Clients.Put(data.SessionId, wsClient)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power - job.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "mid-journey",
 | 
			
		||||
			Remark:    fmt.Sprintf("Variation 操作,任务ID:%s", job.TaskId),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 MJ 任务列表
 | 
			
		||||
func (h *MidJourneyHandler) JobList(c *gin.Context) {
 | 
			
		||||
	status := h.GetInt(c, "status", 0)
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
// ImgWall 照片墙
 | 
			
		||||
func (h *MidJourneyHandler) ImgWall(c *gin.Context) {
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	err, jobs := h.getData(true, 0, page, pageSize, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := h.db.Session(&gorm.Session{})
 | 
			
		||||
	if status == 1 {
 | 
			
		||||
	resp.SUCCESS(c, jobs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 MJ 任务列表
 | 
			
		||||
func (h *MidJourneyHandler) JobList(c *gin.Context) {
 | 
			
		||||
	status := h.GetBool(c, "status")
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	publish := h.GetBool(c, "publish")
 | 
			
		||||
 | 
			
		||||
	err, jobs := h.getData(status, userId, page, pageSize, publish)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, jobs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 MJ 任务列表
 | 
			
		||||
func (h *MidJourneyHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.MidJourneyJob) {
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if finish {
 | 
			
		||||
		session = session.Where("progress = ?", 100).Order("id DESC")
 | 
			
		||||
	} else {
 | 
			
		||||
		session = session.Where("progress < ?", 100).Order("id ASC")
 | 
			
		||||
@@ -328,6 +426,9 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		session = session.Where("user_id = ?", userId)
 | 
			
		||||
	}
 | 
			
		||||
	if publish {
 | 
			
		||||
		session = session.Where("publish = ?", publish)
 | 
			
		||||
	}
 | 
			
		||||
	if page > 0 && pageSize > 0 {
 | 
			
		||||
		offset := (page - 1) * pageSize
 | 
			
		||||
		session = session.Offset(offset).Limit(pageSize)
 | 
			
		||||
@@ -336,8 +437,7 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
 | 
			
		||||
	var items []model.MidJourneyJob
 | 
			
		||||
	res := session.Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, types.NoData)
 | 
			
		||||
		return
 | 
			
		||||
		return res.Error, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var jobs = make([]vo.MidJourneyJob, 0)
 | 
			
		||||
@@ -347,20 +447,73 @@ func (h *MidJourneyHandler) JobList(c *gin.Context) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		if item.Progress < 100 {
 | 
			
		||||
			// 30 分钟还没完成的任务直接删除
 | 
			
		||||
			if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
 | 
			
		||||
				h.db.Delete(&item)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			if item.ImgURL != "" { // 正在运行中任务使用代理访问图片
 | 
			
		||||
				image, err := utils.DownloadImage(item.ImgURL, h.App.Config.ProxyURL)
 | 
			
		||||
 | 
			
		||||
		if item.Progress < 100 && item.ImgURL == "" && item.OrgURL != "" {
 | 
			
		||||
			// discord 服务器图片需要使用代理转发图片数据流
 | 
			
		||||
			if strings.HasPrefix(item.OrgURL, "https://cdn.discordapp.com") {
 | 
			
		||||
				image, err := utils.DownloadImage(item.OrgURL, h.App.Config.ProxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					job.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			} else {
 | 
			
		||||
				job.ImgURL = job.OrgURL
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		jobs = append(jobs, job)
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, jobs)
 | 
			
		||||
	return nil, jobs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove remove task image
 | 
			
		||||
func (h *MidJourneyHandler) Remove(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint   `json:"id"`
 | 
			
		||||
		UserId uint   `json:"user_id"`
 | 
			
		||||
		ImgURL string `json:"img_url"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove job recode
 | 
			
		||||
	res := h.DB.Delete(&model.MidJourneyJob{Id: data.Id})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove image
 | 
			
		||||
	err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("remove image failed: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(data.UserId)
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Publish 发布图片到画廊显示
 | 
			
		||||
func (h *MidJourneyHandler) Publish(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint `json:"id"`
 | 
			
		||||
		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.MidJourneyJob{Id: data.Id}).UpdateColumn("publish", data.Action)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,25 +1,30 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type OrderHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewOrderHandler(app *core.AppServer, db *gorm.DB) *OrderHandler {
 | 
			
		||||
	h := OrderHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &OrderHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *OrderHandler) List(c *gin.Context) {
 | 
			
		||||
@@ -31,8 +36,8 @@ func (h *OrderHandler) List(c *gin.Context) {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	user, _ := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	session := h.db.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", user.Id, types.OrderPaidSuccess)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{}).Where("user_id = ? AND status = ?", userId, types.OrderPaidSuccess)
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.Order{}).Count(&total)
 | 
			
		||||
	var items []model.Order
 | 
			
		||||
 
 | 
			
		||||
@@ -1,80 +1,132 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/payment"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/payment"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"embed"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"github.com/shopspring/decimal"
 | 
			
		||||
	"math"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	PayWayAlipay = "支付宝"
 | 
			
		||||
	PayWayWechat = "微信支付"
 | 
			
		||||
	PayWayXunHu  = "虎皮椒"
 | 
			
		||||
	PayWayJs     = "PayJS"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// PaymentHandler 支付服务回调 handler
 | 
			
		||||
type PaymentHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	alipayService *payment.AlipayService
 | 
			
		||||
	snowflake     *service.Snowflake
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	fs            embed.FS
 | 
			
		||||
	lock          sync.Mutex
 | 
			
		||||
	alipayService  *payment.AlipayService
 | 
			
		||||
	huPiPayService *payment.HuPiPayService
 | 
			
		||||
	js             *payment.PayJS
 | 
			
		||||
	snowflake      *service.Snowflake
 | 
			
		||||
	fs             embed.FS
 | 
			
		||||
	lock           sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPaymentHandler(server *core.AppServer, alipayService *payment.AlipayService, snowflake *service.Snowflake, db *gorm.DB, fs embed.FS) *PaymentHandler {
 | 
			
		||||
	h := PaymentHandler{lock: sync.Mutex{}}
 | 
			
		||||
	h.App = server
 | 
			
		||||
	h.alipayService = alipayService
 | 
			
		||||
	h.snowflake = snowflake
 | 
			
		||||
	h.db = db
 | 
			
		||||
	h.fs = fs
 | 
			
		||||
	return &h
 | 
			
		||||
func NewPaymentHandler(
 | 
			
		||||
	server *core.AppServer,
 | 
			
		||||
	alipayService *payment.AlipayService,
 | 
			
		||||
	huPiPayService *payment.HuPiPayService,
 | 
			
		||||
	js *payment.PayJS,
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	snowflake *service.Snowflake,
 | 
			
		||||
	fs embed.FS) *PaymentHandler {
 | 
			
		||||
	return &PaymentHandler{
 | 
			
		||||
		alipayService:  alipayService,
 | 
			
		||||
		huPiPayService: huPiPayService,
 | 
			
		||||
		js:             js,
 | 
			
		||||
		snowflake:      snowflake,
 | 
			
		||||
		fs:             fs,
 | 
			
		||||
		lock:           sync.Mutex{},
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: server,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *PaymentHandler) Alipay(c *gin.Context) {
 | 
			
		||||
func (h *PaymentHandler) DoPay(c *gin.Context) {
 | 
			
		||||
	orderNo := h.GetTrim(c, "order_no")
 | 
			
		||||
	payWay := h.GetTrim(c, "pay_way")
 | 
			
		||||
 | 
			
		||||
	if orderNo == "" {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var order model.Order
 | 
			
		||||
	res := h.db.Where("order_no = ?", orderNo).First(&order)
 | 
			
		||||
	res := h.DB.Where("order_no = ?", orderNo).First(&order)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Order not found")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新扫码状态
 | 
			
		||||
	h.db.Model(&order).UpdateColumn("status", types.OrderScanned)
 | 
			
		||||
	// 生成支付链接
 | 
			
		||||
	notifyURL := h.App.Config.AlipayConfig.NotifyURL
 | 
			
		||||
	returnURL := "" // 关闭同步回跳
 | 
			
		||||
	amount := fmt.Sprintf("%.2f", order.Amount)
 | 
			
		||||
 | 
			
		||||
	uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "error with generate pay url: "+err.Error())
 | 
			
		||||
	// fix: 这里先检查一下订单状态,如果已经支付了,就直接返回
 | 
			
		||||
	if order.Status == types.OrderPaidSuccess {
 | 
			
		||||
		resp.ERROR(c, "This order had been paid, please do not pay twice")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.Redirect(302, uri)
 | 
			
		||||
	// 更新扫码状态
 | 
			
		||||
	h.DB.Model(&order).UpdateColumn("status", types.OrderScanned)
 | 
			
		||||
	if payWay == "alipay" { // 支付宝
 | 
			
		||||
		// 生成支付链接
 | 
			
		||||
		notifyURL := h.App.Config.AlipayConfig.NotifyURL
 | 
			
		||||
		returnURL := "" // 关闭同步回跳
 | 
			
		||||
		amount := fmt.Sprintf("%.2f", order.Amount)
 | 
			
		||||
 | 
			
		||||
		uri, err := h.alipayService.PayUrlMobile(order.OrderNo, notifyURL, returnURL, amount, order.Subject)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "error with generate pay url: "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c.Redirect(302, uri)
 | 
			
		||||
		return
 | 
			
		||||
	} else if payWay == "hupi" { // 虎皮椒支付
 | 
			
		||||
		params := payment.HuPiPayReq{
 | 
			
		||||
			Version:      "1.1",
 | 
			
		||||
			TradeOrderId: orderNo,
 | 
			
		||||
			TotalFee:     fmt.Sprintf("%f", order.Amount),
 | 
			
		||||
			Title:        order.Subject,
 | 
			
		||||
			NotifyURL:    h.App.Config.HuPiPayConfig.NotifyURL,
 | 
			
		||||
			WapName:      "极客学长",
 | 
			
		||||
		}
 | 
			
		||||
		r, err := h.huPiPayService.Pay(params)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		c.Redirect(302, r.URL)
 | 
			
		||||
	}
 | 
			
		||||
	resp.ERROR(c, "Invalid operations")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OrderQuery 清单状态查询
 | 
			
		||||
// OrderQuery 查询订单状态
 | 
			
		||||
func (h *PaymentHandler) OrderQuery(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		OrderNo string `json:"order_no"`
 | 
			
		||||
@@ -85,7 +137,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var order model.Order
 | 
			
		||||
	res := h.db.Where("order_no = ?", data.OrderNo).First(&order)
 | 
			
		||||
	res := h.DB.Where("order_no = ?", data.OrderNo).First(&order)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Order not found")
 | 
			
		||||
		return
 | 
			
		||||
@@ -100,7 +152,7 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
 | 
			
		||||
	for {
 | 
			
		||||
		time.Sleep(time.Second)
 | 
			
		||||
		var item model.Order
 | 
			
		||||
		h.db.Where("order_no = ?", data.OrderNo).First(&item)
 | 
			
		||||
		h.DB.Where("order_no = ?", data.OrderNo).First(&item)
 | 
			
		||||
		if counter >= 15 || item.Status == types.OrderPaidSuccess || item.Status != order.Status {
 | 
			
		||||
			order.Status = item.Status
 | 
			
		||||
			break
 | 
			
		||||
@@ -111,16 +163,12 @@ func (h *PaymentHandler) OrderQuery(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, gin.H{"status": order.Status})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AlipayQrcode 生成支付宝支付 URL 二维码
 | 
			
		||||
func (h *PaymentHandler) AlipayQrcode(c *gin.Context) {
 | 
			
		||||
	if !h.App.SysConfig.EnabledAlipay || h.alipayService == nil {
 | 
			
		||||
		resp.ERROR(c, "当前支付通道已经关闭,请联系管理员开通!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
// PayQrcode 生成支付 URL 二维码
 | 
			
		||||
func (h *PaymentHandler) PayQrcode(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		ProductId uint `json:"product_id"`
 | 
			
		||||
		UserId    int  `json:"user_id"`
 | 
			
		||||
		PayWay    string `json:"pay_way"` // 支付方式
 | 
			
		||||
		ProductId uint   `json:"product_id"`
 | 
			
		||||
		UserId    int    `json:"user_id"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
@@ -128,62 +176,105 @@ func (h *PaymentHandler) AlipayQrcode(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var product model.Product
 | 
			
		||||
	res := h.db.First(&product, data.ProductId)
 | 
			
		||||
	res := h.DB.First(&product, data.ProductId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Product not found")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	orderNo, err := h.snowflake.Next()
 | 
			
		||||
	orderNo, err := h.snowflake.Next(false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "error with generate trade no: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res = h.db.First(&user, data.UserId)
 | 
			
		||||
	res = h.DB.First(&user, data.UserId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Invalid user ID")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var payWay string
 | 
			
		||||
	var notifyURL string
 | 
			
		||||
	switch data.PayWay {
 | 
			
		||||
	case "hupi":
 | 
			
		||||
		payWay = PayWayXunHu
 | 
			
		||||
		notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
 | 
			
		||||
	case "payjs":
 | 
			
		||||
		payWay = PayWayJs
 | 
			
		||||
		notifyURL = h.App.Config.JPayConfig.NotifyURL
 | 
			
		||||
	default:
 | 
			
		||||
		payWay = PayWayAlipay
 | 
			
		||||
		notifyURL = h.App.Config.AlipayConfig.NotifyURL
 | 
			
		||||
	}
 | 
			
		||||
	// 创建订单
 | 
			
		||||
	remark := types.OrderRemark{
 | 
			
		||||
		Days:     product.Days,
 | 
			
		||||
		Calls:    product.Calls,
 | 
			
		||||
		Power:    product.Power,
 | 
			
		||||
		Name:     product.Name,
 | 
			
		||||
		Price:    product.Price,
 | 
			
		||||
		Discount: product.Discount,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
 | 
			
		||||
	order := model.Order{
 | 
			
		||||
		UserId:    user.Id,
 | 
			
		||||
		Mobile:    user.Mobile,
 | 
			
		||||
		Username:  user.Username,
 | 
			
		||||
		ProductId: product.Id,
 | 
			
		||||
		OrderNo:   orderNo,
 | 
			
		||||
		Subject:   product.Name,
 | 
			
		||||
		Amount:    product.Price - product.Discount,
 | 
			
		||||
		Amount:    amount,
 | 
			
		||||
		Status:    types.OrderNotPaid,
 | 
			
		||||
		PayWay:    PayWayAlipay,
 | 
			
		||||
		PayWay:    payWay,
 | 
			
		||||
		Remark:    utils.JsonEncode(remark),
 | 
			
		||||
	}
 | 
			
		||||
	res = h.db.Create(&order)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
	res = h.DB.Create(&order)
 | 
			
		||||
	if res.Error != nil || res.RowsAffected == 0 {
 | 
			
		||||
		resp.ERROR(c, "error with create order: "+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 生成二维码图片
 | 
			
		||||
	file, err := h.fs.Open("res/img/alipay.jpg")
 | 
			
		||||
	// PayJs 单独处理,只能用官方生成的二维码
 | 
			
		||||
	if data.PayWay == "payjs" {
 | 
			
		||||
		params := payment.JPayReq{
 | 
			
		||||
			TotalFee:   int(math.Ceil(order.Amount * 100)),
 | 
			
		||||
			OutTradeNo: order.OrderNo,
 | 
			
		||||
			Subject:    product.Name,
 | 
			
		||||
		}
 | 
			
		||||
		r := h.js.Pay(params)
 | 
			
		||||
		if r.IsOK() {
 | 
			
		||||
			resp.SUCCESS(c, gin.H{"order_no": order.OrderNo, "image": r.Qrcode})
 | 
			
		||||
			return
 | 
			
		||||
		} else {
 | 
			
		||||
			resp.ERROR(c, "error with generating payment qrcode: "+r.ReturnMsg)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var logo string
 | 
			
		||||
	if data.PayWay == "alipay" {
 | 
			
		||||
		logo = "res/img/alipay.jpg"
 | 
			
		||||
	} else if data.PayWay == "hupi" {
 | 
			
		||||
		if h.App.Config.HuPiPayConfig.Name == "wechat" {
 | 
			
		||||
			logo = "res/img/wechat-pay.jpg"
 | 
			
		||||
		} else {
 | 
			
		||||
			logo = "res/img/alipay.jpg"
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	file, err := h.fs.Open(logo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		resp.ERROR(c, "error with open qrcode log file: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	parse, err := url.Parse(h.App.Config.AlipayConfig.NotifyURL)
 | 
			
		||||
 | 
			
		||||
	parse, err := url.Parse(notifyURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	imageURL := fmt.Sprintf("%s://%s/api/payment/alipay?order_no=%s", parse.Scheme, parse.Host, orderNo)
 | 
			
		||||
	imageURL := fmt.Sprintf("%s://%s/api/payment/doPay?order_no=%s&pay_way=%s", parse.Scheme, parse.Host, orderNo, data.PayWay)
 | 
			
		||||
	imgData, err := utils.GenQrcode(imageURL, 400, file)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
@@ -193,6 +284,252 @@ func (h *PaymentHandler) AlipayQrcode(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, gin.H{"order_no": orderNo, "image": fmt.Sprintf("data:image/jpg;base64, %s", imgDataBase64), "url": imageURL})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Mobile 移动端支付
 | 
			
		||||
func (h *PaymentHandler) Mobile(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		PayWay    string `json:"pay_way"` // 支付方式
 | 
			
		||||
		ProductId uint   `json:"product_id"`
 | 
			
		||||
		UserId    int    `json:"user_id"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var product model.Product
 | 
			
		||||
	res := h.DB.First(&product, data.ProductId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Product not found")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	orderNo, err := h.snowflake.Next(false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "error with generate trade no: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res = h.DB.First(&user, data.UserId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "Invalid user ID")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	amount, _ := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Float64()
 | 
			
		||||
	var payWay string
 | 
			
		||||
	var notifyURL, returnURL string
 | 
			
		||||
	var payURL string
 | 
			
		||||
	switch data.PayWay {
 | 
			
		||||
	case "hupi":
 | 
			
		||||
		payWay = PayWayXunHu
 | 
			
		||||
		notifyURL = h.App.Config.HuPiPayConfig.NotifyURL
 | 
			
		||||
		returnURL = h.App.Config.HuPiPayConfig.ReturnURL
 | 
			
		||||
		params := payment.HuPiPayReq{
 | 
			
		||||
			Version:      "1.1",
 | 
			
		||||
			TradeOrderId: orderNo,
 | 
			
		||||
			TotalFee:     fmt.Sprintf("%f", amount),
 | 
			
		||||
			Title:        product.Name,
 | 
			
		||||
			NotifyURL:    notifyURL,
 | 
			
		||||
			ReturnURL:    returnURL,
 | 
			
		||||
			CallbackURL:  returnURL,
 | 
			
		||||
			WapName:      "极客学长",
 | 
			
		||||
		}
 | 
			
		||||
		r, err := h.huPiPayService.Pay(params)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with generating Pay URL: ", err.Error())
 | 
			
		||||
			resp.ERROR(c, "error with generating Pay URL: "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		payURL = r.URL
 | 
			
		||||
	case "payjs":
 | 
			
		||||
		payWay = PayWayJs
 | 
			
		||||
		notifyURL = h.App.Config.JPayConfig.NotifyURL
 | 
			
		||||
		returnURL = h.App.Config.JPayConfig.ReturnURL
 | 
			
		||||
		totalFee := decimal.NewFromFloat(product.Price).Sub(decimal.NewFromFloat(product.Discount)).Mul(decimal.NewFromInt(100)).IntPart()
 | 
			
		||||
		params := url.Values{}
 | 
			
		||||
		params.Add("total_fee", fmt.Sprintf("%d", totalFee))
 | 
			
		||||
		params.Add("out_trade_no", orderNo)
 | 
			
		||||
		params.Add("body", product.Name)
 | 
			
		||||
		params.Add("notify_url", notifyURL)
 | 
			
		||||
		params.Add("auto", "0")
 | 
			
		||||
		payURL = h.js.PayH5(params)
 | 
			
		||||
	case "alipay":
 | 
			
		||||
		payWay = PayWayAlipay
 | 
			
		||||
		notifyURL = h.App.Config.AlipayConfig.NotifyURL
 | 
			
		||||
		returnURL = h.App.Config.AlipayConfig.ReturnURL
 | 
			
		||||
		payURL, err = h.alipayService.PayUrlMobile(orderNo, notifyURL, returnURL, fmt.Sprintf("%.2f", amount), product.Name)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			resp.ERROR(c, "error with generating Pay URL: "+err.Error())
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	default:
 | 
			
		||||
		resp.ERROR(c, "Unsupported pay way: "+data.PayWay)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 创建订单
 | 
			
		||||
	remark := types.OrderRemark{
 | 
			
		||||
		Days:     product.Days,
 | 
			
		||||
		Power:    product.Power,
 | 
			
		||||
		Name:     product.Name,
 | 
			
		||||
		Price:    product.Price,
 | 
			
		||||
		Discount: product.Discount,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	order := model.Order{
 | 
			
		||||
		UserId:    user.Id,
 | 
			
		||||
		Username:  user.Username,
 | 
			
		||||
		ProductId: product.Id,
 | 
			
		||||
		OrderNo:   orderNo,
 | 
			
		||||
		Subject:   product.Name,
 | 
			
		||||
		Amount:    amount,
 | 
			
		||||
		Status:    types.OrderNotPaid,
 | 
			
		||||
		PayWay:    payWay,
 | 
			
		||||
		Remark:    utils.JsonEncode(remark),
 | 
			
		||||
	}
 | 
			
		||||
	res = h.DB.Create(&order)
 | 
			
		||||
	if res.Error != nil || res.RowsAffected == 0 {
 | 
			
		||||
		resp.ERROR(c, "error with create order: "+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, payURL)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// 异步通知回调公共逻辑
 | 
			
		||||
func (h *PaymentHandler) notify(orderNo string, tradeNo string) error {
 | 
			
		||||
	var order model.Order
 | 
			
		||||
	res := h.DB.Where("order_no = ?", orderNo).First(&order)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		err := fmt.Errorf("error with fetch order: %v", res.Error)
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.lock.Lock()
 | 
			
		||||
	defer h.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	// 已支付订单,直接返回
 | 
			
		||||
	if order.Status == types.OrderPaidSuccess {
 | 
			
		||||
		return nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res = h.DB.First(&user, order.UserId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		err := fmt.Errorf("error with fetch user info: %v", res.Error)
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var remark types.OrderRemark
 | 
			
		||||
	err := utils.JsonDecode(order.Remark, &remark)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		err := fmt.Errorf("error with decode order remark: %v", err)
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var opt string
 | 
			
		||||
	var power int
 | 
			
		||||
	if remark.Days > 0 { // VIP 充值
 | 
			
		||||
		if user.ExpiredTime >= time.Now().Unix() {
 | 
			
		||||
			user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
 | 
			
		||||
			opt = "VIP充值,VIP 没到期,只延期不增加算力"
 | 
			
		||||
		} else {
 | 
			
		||||
			user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
 | 
			
		||||
			user.Power += h.App.SysConfig.VipMonthPower
 | 
			
		||||
			power = h.App.SysConfig.VipMonthPower
 | 
			
		||||
			opt = "VIP充值"
 | 
			
		||||
		}
 | 
			
		||||
		user.Vip = true
 | 
			
		||||
	} else { // 充值点卡,直接增加次数即可
 | 
			
		||||
		user.Power += remark.Power
 | 
			
		||||
		opt = "点卡充值"
 | 
			
		||||
		power = remark.Power
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户信息
 | 
			
		||||
	res = h.DB.Updates(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		err := fmt.Errorf("error with update user info: %v", res.Error)
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新订单状态
 | 
			
		||||
	order.PayTime = time.Now().Unix()
 | 
			
		||||
	order.Status = types.OrderPaidSuccess
 | 
			
		||||
	order.TradeNo = tradeNo
 | 
			
		||||
	res = h.DB.Updates(&order)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		err := fmt.Errorf("error with update order info: %v", res.Error)
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新产品销量
 | 
			
		||||
	h.DB.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
 | 
			
		||||
 | 
			
		||||
	// 记录算力充值日志
 | 
			
		||||
	if opt != "" {
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerRecharge,
 | 
			
		||||
			Amount:    power,
 | 
			
		||||
			Balance:   user.Power,
 | 
			
		||||
			Mark:      types.PowerAdd,
 | 
			
		||||
			Model:     order.PayWay,
 | 
			
		||||
			Remark:    fmt.Sprintf("%s,金额:%f,订单号:%s", opt, order.Amount, order.OrderNo),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// GetPayWays 获取支付方式
 | 
			
		||||
func (h *PaymentHandler) GetPayWays(c *gin.Context) {
 | 
			
		||||
	data := gin.H{}
 | 
			
		||||
	if h.App.Config.AlipayConfig.Enabled {
 | 
			
		||||
		data["alipay"] = gin.H{"name": "alipay"}
 | 
			
		||||
	}
 | 
			
		||||
	if h.App.Config.HuPiPayConfig.Enabled {
 | 
			
		||||
		data["hupi"] = gin.H{"name": h.App.Config.HuPiPayConfig.Name}
 | 
			
		||||
	}
 | 
			
		||||
	if h.App.Config.JPayConfig.Enabled {
 | 
			
		||||
		data["payjs"] = gin.H{"name": h.App.Config.JPayConfig.Name}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, data)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HuPiPayNotify 虎皮椒支付异步回调
 | 
			
		||||
func (h *PaymentHandler) HuPiPayNotify(c *gin.Context) {
 | 
			
		||||
	err := c.Request.ParseForm()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	orderNo := c.Request.Form.Get("trade_order_id")
 | 
			
		||||
	tradeNo := c.Request.Form.Get("open_order_id")
 | 
			
		||||
	logger.Infof("收到虎皮椒订单支付回调,订单 NO:%s,交易流水号:%s", orderNo, tradeNo)
 | 
			
		||||
 | 
			
		||||
	if err = h.huPiPayService.Check(tradeNo); err != nil {
 | 
			
		||||
		logger.Error("订单校验失败:", err)
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	err = h.notify(orderNo, tradeNo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.String(http.StatusOK, "success")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// AlipayNotify 支付宝支付回调
 | 
			
		||||
func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
 | 
			
		||||
	err := c.Request.ParseForm()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -200,74 +537,55 @@ func (h *PaymentHandler) AlipayNotify(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// TODO:这里最好用支付宝的公钥签名签证一下交易真假
 | 
			
		||||
	//res := h.alipayService.TradeVerify(c.Request.Form)
 | 
			
		||||
	r := h.alipayService.TradeQuery(c.Request.Form.Get("out_trade_no"))
 | 
			
		||||
	logger.Infof("验证支付结果:%+v", r)
 | 
			
		||||
	if !r.Success() {
 | 
			
		||||
	// TODO:验证交易签名
 | 
			
		||||
	res := h.alipayService.TradeVerify(c.Request.Form)
 | 
			
		||||
	logger.Infof("验证支付结果:%+v", res)
 | 
			
		||||
	if !res.Success() {
 | 
			
		||||
		logger.Error("订单校验失败:", res.Message)
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.lock.Lock()
 | 
			
		||||
	defer h.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	var order model.Order
 | 
			
		||||
	res := h.db.Where("order_no = ?", r.OutTradeNo).First(&order)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error(res.Error)
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res = h.db.First(&user, order.UserId)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error(res.Error)
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var remark types.OrderRemark
 | 
			
		||||
	err = utils.JsonDecode(order.Remark, &remark)
 | 
			
		||||
	tradeNo := c.Request.Form.Get("trade_no")
 | 
			
		||||
	err = h.notify(res.OutTradeNo, tradeNo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	c.String(http.StatusOK, "success")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PayJsNotify PayJs 支付异步回调
 | 
			
		||||
func (h *PaymentHandler) PayJsNotify(c *gin.Context) {
 | 
			
		||||
	err := c.Request.ParseForm()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	orderNo := c.Request.Form.Get("out_trade_no")
 | 
			
		||||
	returnCode := c.Request.Form.Get("return_code")
 | 
			
		||||
	logger.Infof("收到订单支付回调,订单 NO:%s,支付结果代码:%v", orderNo, returnCode)
 | 
			
		||||
	// 支付失败
 | 
			
		||||
	if returnCode != "1" {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 校验订单支付状态
 | 
			
		||||
	tradeNo := c.Request.Form.Get("payjs_order_id")
 | 
			
		||||
	err = h.js.Check(tradeNo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("订单校验失败:", err)
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	err = h.notify(orderNo, tradeNo)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(res.Error)
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// 1. 点卡:days == 0, calls > 0
 | 
			
		||||
	// 2. vip 套餐:days > 0, calls == 0
 | 
			
		||||
	if remark.Days > 0 {
 | 
			
		||||
		if user.ExpiredTime > time.Now().Unix() {
 | 
			
		||||
			user.ExpiredTime = time.Unix(user.ExpiredTime, 0).AddDate(0, 0, remark.Days).Unix()
 | 
			
		||||
		} else {
 | 
			
		||||
			user.ExpiredTime = time.Now().AddDate(0, 0, remark.Days).Unix()
 | 
			
		||||
		}
 | 
			
		||||
		user.Vip = true
 | 
			
		||||
 | 
			
		||||
	} else if !user.Vip { // 充值点卡的非 VIP 用户
 | 
			
		||||
		user.ExpiredTime = time.Now().AddDate(0, 0, 30).Unix()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if remark.Calls > 0 { // 充值点卡
 | 
			
		||||
		user.Calls += remark.Calls
 | 
			
		||||
	} else {
 | 
			
		||||
		user.Calls += h.App.SysConfig.VipMonthCalls
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户信息
 | 
			
		||||
	res = h.db.Updates(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error(res.Error)
 | 
			
		||||
		c.String(http.StatusOK, "fail")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新订单状态
 | 
			
		||||
	order.PayTime = time.Now().Unix()
 | 
			
		||||
	order.Status = types.OrderPaidSuccess
 | 
			
		||||
	h.db.Updates(&order)
 | 
			
		||||
 | 
			
		||||
	// 更新产品销量
 | 
			
		||||
	h.db.Model(&model.Product{}).Where("id = ?", order.ProductId).UpdateColumn("sales", gorm.Expr("sales + ?", 1))
 | 
			
		||||
 | 
			
		||||
	c.String(http.StatusOK, "success")
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										74
									
								
								api/handler/power_log_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										74
									
								
								api/handler/power_log_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,74 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PowerLogHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPowerLogHandler(app *core.AppServer, db *gorm.DB) *PowerLogHandler {
 | 
			
		||||
	return &PowerLogHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *PowerLogHandler) List(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Model    string   `json:"model"`
 | 
			
		||||
		Date     []string `json:"date"`
 | 
			
		||||
		Page     int      `json:"page"`
 | 
			
		||||
		PageSize int      `json:"page_size"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	session = session.Where("user_id", userId)
 | 
			
		||||
	if data.Model != "" {
 | 
			
		||||
		session = session.Where("model", data.Model)
 | 
			
		||||
	}
 | 
			
		||||
	if len(data.Date) == 2 {
 | 
			
		||||
		start := data.Date[0] + " 00:00:00"
 | 
			
		||||
		end := data.Date[1] + " 00:00:00"
 | 
			
		||||
		session = session.Where("created_at >= ? AND created_at <= ?", start, end)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var total int64
 | 
			
		||||
	session.Model(&model.PowerLog{}).Count(&total)
 | 
			
		||||
	var items []model.PowerLog
 | 
			
		||||
	var list = make([]vo.PowerLog, 0)
 | 
			
		||||
	offset := (data.Page - 1) * data.PageSize
 | 
			
		||||
	res := session.Order("id DESC").Offset(offset).Limit(data.PageSize).Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var log vo.PowerLog
 | 
			
		||||
			err := utils.CopyObject(item, &log)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			log.Id = item.Id
 | 
			
		||||
			log.CreatedAt = item.CreatedAt.Unix()
 | 
			
		||||
			log.TypeStr = item.Type.String()
 | 
			
		||||
			list = append(list, log)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, vo.NewPage(total, data.Page, data.PageSize, list))
 | 
			
		||||
}
 | 
			
		||||
@@ -1,31 +1,35 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type ProductHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewProductHandler(app *core.AppServer, db *gorm.DB) *ProductHandler {
 | 
			
		||||
	h := ProductHandler{db: db}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
	return &ProductHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// List 模型列表
 | 
			
		||||
func (h *ProductHandler) List(c *gin.Context) {
 | 
			
		||||
	var items []model.Product
 | 
			
		||||
	var list = make([]vo.Product, 0)
 | 
			
		||||
	res := h.db.Where("enabled", true).Order("sort_num ASC").Find(&items)
 | 
			
		||||
	res := h.DB.Where("enabled", true).Order("sort_num ASC").Find(&items)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		for _, item := range items {
 | 
			
		||||
			var product vo.Product
 | 
			
		||||
 
 | 
			
		||||
@@ -1,25 +1,35 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"math"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"sync"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type RewardHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	db *gorm.DB
 | 
			
		||||
	lock sync.Mutex
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewRewardHandler(server *core.AppServer, db *gorm.DB) *RewardHandler {
 | 
			
		||||
	h := RewardHandler{db: db}
 | 
			
		||||
	h.App = server
 | 
			
		||||
	return &h
 | 
			
		||||
func NewRewardHandler(app *core.AppServer, db *gorm.DB) *RewardHandler {
 | 
			
		||||
	return &RewardHandler{BaseHandler: BaseHandler{App: app, DB: db}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Verify 打赏码核销
 | 
			
		||||
@@ -32,11 +42,20 @@ func (h *RewardHandler) Verify(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.HACKER(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 移除转账单号中间的空格,防止有人复制的时候多复制了空格
 | 
			
		||||
	data.TxId = strings.ReplaceAll(data.TxId, " ", "")
 | 
			
		||||
 | 
			
		||||
	h.lock.Lock()
 | 
			
		||||
	defer h.lock.Unlock()
 | 
			
		||||
 | 
			
		||||
	var item model.Reward
 | 
			
		||||
	res := h.db.Where("tx_id = ?", data.TxId).First(&item)
 | 
			
		||||
	res := h.DB.Where("tx_id = ?", data.TxId).First(&item)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "无效的众筹交易流水号!")
 | 
			
		||||
		return
 | 
			
		||||
@@ -47,16 +66,13 @@ func (h *RewardHandler) Verify(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.HACKER(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	tx := h.db.Begin()
 | 
			
		||||
	calls := (item.Amount + 0.1) * 10
 | 
			
		||||
	res = h.db.Model(&user).UpdateColumn("calls", gorm.Expr("calls + ?", calls))
 | 
			
		||||
	tx := h.DB.Begin()
 | 
			
		||||
	exchange := vo.RewardExchange{}
 | 
			
		||||
	power := math.Ceil(item.Amount / h.App.SysConfig.PowerPrice)
 | 
			
		||||
	exchange.Power = int(power)
 | 
			
		||||
	res = tx.Model(&user).UpdateColumn("power", gorm.Expr("power + ?", exchange.Power))
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
@@ -64,13 +80,26 @@ func (h *RewardHandler) Verify(c *gin.Context) {
 | 
			
		||||
	// 更新核销状态
 | 
			
		||||
	item.Status = true
 | 
			
		||||
	item.UserId = user.Id
 | 
			
		||||
	res = h.db.Updates(&item)
 | 
			
		||||
	item.Exchange = utils.JsonEncode(exchange)
 | 
			
		||||
	res = tx.Updates(&item)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		tx.Rollback()
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 记录算力充值日志
 | 
			
		||||
	h.DB.Create(&model.PowerLog{
 | 
			
		||||
		UserId:    user.Id,
 | 
			
		||||
		Username:  user.Username,
 | 
			
		||||
		Type:      types.PowerReward,
 | 
			
		||||
		Amount:    exchange.Power,
 | 
			
		||||
		Balance:   user.Power + exchange.Power,
 | 
			
		||||
		Mark:      types.PowerAdd,
 | 
			
		||||
		Model:     "众筹支付",
 | 
			
		||||
		Remark:    fmt.Sprintf("众筹充值算力,金额:%f,价格:%f", item.Amount, h.App.SysConfig.PowerPrice),
 | 
			
		||||
		CreatedAt: time.Now(),
 | 
			
		||||
	})
 | 
			
		||||
	tx.Commit()
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -1,37 +1,54 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/sd"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type SdJobHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	redis   *redis.Client
 | 
			
		||||
	db      *gorm.DB
 | 
			
		||||
	service *sd.Service
 | 
			
		||||
	redis     *redis.Client
 | 
			
		||||
	pool      *sd.ServicePool
 | 
			
		||||
	uploader  *oss.UploaderManager
 | 
			
		||||
	snowflake *service.Snowflake
 | 
			
		||||
	leveldb   *store.LevelDB
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSdJobHandler(app *core.AppServer, redisCli *redis.Client, db *gorm.DB, service *sd.Service) *SdJobHandler {
 | 
			
		||||
	h := SdJobHandler{
 | 
			
		||||
		redis:   redisCli,
 | 
			
		||||
		db:      db,
 | 
			
		||||
		service: service,
 | 
			
		||||
func NewSdJobHandler(app *core.AppServer, db *gorm.DB, pool *sd.ServicePool, manager *oss.UploaderManager, snowflake *service.Snowflake, levelDB *store.LevelDB) *SdJobHandler {
 | 
			
		||||
	return &SdJobHandler{
 | 
			
		||||
		pool:      pool,
 | 
			
		||||
		uploader:  manager,
 | 
			
		||||
		snowflake: snowflake,
 | 
			
		||||
		leveldb:   levelDB,
 | 
			
		||||
		BaseHandler: BaseHandler{
 | 
			
		||||
			App: app,
 | 
			
		||||
			DB:  db,
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
	h.App = app
 | 
			
		||||
	return &h
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Client WebSocket 客户端,用于通知任务状态变更
 | 
			
		||||
@@ -39,25 +56,36 @@ func (h *SdJobHandler) Client(c *gin.Context) {
 | 
			
		||||
	ws, err := (&websocket.Upgrader{CheckOrigin: func(r *http.Request) bool { return true }}).Upgrade(c.Writer, c.Request, nil)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error(err)
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
	if userId == 0 {
 | 
			
		||||
		logger.Info("Invalid user ID")
 | 
			
		||||
		c.Abort()
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	sessionId := c.Query("session_id")
 | 
			
		||||
	client := types.NewWsClient(ws)
 | 
			
		||||
	// 删除旧的连接
 | 
			
		||||
	h.service.Clients.Put(sessionId, client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.ClientIP())
 | 
			
		||||
	h.pool.Clients.Put(uint(userId), client)
 | 
			
		||||
	logger.Infof("New websocket connected, IP: %s", c.RemoteIP())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
func (h *SdJobHandler) preCheck(c *gin.Context) bool {
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.ImgCalls <= 0 {
 | 
			
		||||
		resp.ERROR(c, "您的绘图次数不足,请联系管理员充值!")
 | 
			
		||||
	if !h.pool.HasAvailableService() {
 | 
			
		||||
		resp.ERROR(c, "Stable-Diffusion 池子中没有没有可用的服务!")
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if user.Power < h.App.SysConfig.SdPower {
 | 
			
		||||
		resp.ERROR(c, "当前用户剩余算力不足以完成本次绘画!")
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -67,12 +95,7 @@ func (h *SdJobHandler) checkLimits(c *gin.Context) bool {
 | 
			
		||||
 | 
			
		||||
// Image 创建一个绘画任务
 | 
			
		||||
func (h *SdJobHandler) Image(c *gin.Context) {
 | 
			
		||||
	if !h.App.Config.SdConfig.Enabled {
 | 
			
		||||
		resp.ERROR(c, "Stable Diffusion service is disabled")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if !h.checkLimits(c) {
 | 
			
		||||
	if !h.preCheck(c) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
@@ -105,23 +128,29 @@ func (h *SdJobHandler) Image(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
	idValue, _ := c.Get(types.LoginUserID)
 | 
			
		||||
	userId := utils.IntValue(utils.InterfaceToString(idValue), 0)
 | 
			
		||||
	params := types.SdTaskParams{
 | 
			
		||||
		TaskId:         fmt.Sprintf("task(%s)", utils.RandString(15)),
 | 
			
		||||
		Prompt:         data.Prompt,
 | 
			
		||||
		NegativePrompt: data.NegativePrompt,
 | 
			
		||||
		Steps:          data.Steps,
 | 
			
		||||
		Sampler:        data.Sampler,
 | 
			
		||||
		FaceFix:        data.FaceFix,
 | 
			
		||||
		CfgScale:       data.CfgScale,
 | 
			
		||||
		Seed:           data.Seed,
 | 
			
		||||
		Height:         data.Height,
 | 
			
		||||
		Width:          data.Width,
 | 
			
		||||
		HdFix:          data.HdFix,
 | 
			
		||||
		HdRedrawRate:   data.HdRedrawRate,
 | 
			
		||||
		HdScale:        data.HdScale,
 | 
			
		||||
		HdScaleAlg:     data.HdScaleAlg,
 | 
			
		||||
		HdSteps:        data.HdSteps,
 | 
			
		||||
	taskId, err := h.snowflake.Next(true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "error with generate task id: "+err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	params := types.SdTaskParams{
 | 
			
		||||
		TaskId:       taskId,
 | 
			
		||||
		Prompt:       data.Prompt,
 | 
			
		||||
		NegPrompt:    data.NegPrompt,
 | 
			
		||||
		Steps:        data.Steps,
 | 
			
		||||
		Sampler:      data.Sampler,
 | 
			
		||||
		FaceFix:      data.FaceFix,
 | 
			
		||||
		CfgScale:     data.CfgScale,
 | 
			
		||||
		Seed:         data.Seed,
 | 
			
		||||
		Height:       data.Height,
 | 
			
		||||
		Width:        data.Width,
 | 
			
		||||
		HdFix:        data.HdFix,
 | 
			
		||||
		HdRedrawRate: data.HdRedrawRate,
 | 
			
		||||
		HdScale:      data.HdScale,
 | 
			
		||||
		HdScaleAlg:   data.HdScaleAlg,
 | 
			
		||||
		HdSteps:      data.HdSteps,
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	job := model.SdJob{
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
		Type:      types.TaskImage.String(),
 | 
			
		||||
@@ -129,45 +158,84 @@ func (h *SdJobHandler) Image(c *gin.Context) {
 | 
			
		||||
		Params:    utils.JsonEncode(params),
 | 
			
		||||
		Prompt:    data.Prompt,
 | 
			
		||||
		Progress:  0,
 | 
			
		||||
		Started:   false,
 | 
			
		||||
		Power:     h.App.SysConfig.SdPower,
 | 
			
		||||
		CreatedAt: time.Now(),
 | 
			
		||||
	}
 | 
			
		||||
	res := h.db.Create(&job)
 | 
			
		||||
	res := h.DB.Create(&job)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "error with save job: "+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.service.PushTask(types.SdTask{
 | 
			
		||||
	h.pool.PushTask(types.SdTask{
 | 
			
		||||
		Id:        int(job.Id),
 | 
			
		||||
		SessionId: data.SessionId,
 | 
			
		||||
		Src:       types.TaskSrcImg,
 | 
			
		||||
		Type:      types.TaskImage,
 | 
			
		||||
		Prompt:    data.Prompt,
 | 
			
		||||
		Params:    params,
 | 
			
		||||
		UserId:    userId,
 | 
			
		||||
	})
 | 
			
		||||
	var jobVo vo.SdJob
 | 
			
		||||
	err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		// 推送任务到前端
 | 
			
		||||
		client := h.service.Clients.Get(data.SessionId)
 | 
			
		||||
		if client != nil {
 | 
			
		||||
			utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(uint(job.UserId))
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte("Task Updated"))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update user's power
 | 
			
		||||
	tx := h.DB.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power - ?", job.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		user, _ := h.GetLoginUser(c)
 | 
			
		||||
		h.DB.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    job.Power,
 | 
			
		||||
			Balance:   user.Power - job.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "stable-diffusion",
 | 
			
		||||
			Remark:    fmt.Sprintf("绘图操作,任务ID:%s", job.TaskId),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 stable diffusion 任务列表
 | 
			
		||||
func (h *SdJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
	status := h.GetInt(c, "status", 0)
 | 
			
		||||
	userId := h.GetInt(c, "user_id", 0)
 | 
			
		||||
// ImgWall 照片墙
 | 
			
		||||
func (h *SdJobHandler) ImgWall(c *gin.Context) {
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	err, jobs := h.getData(true, 0, page, pageSize, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	session := h.db.Session(&gorm.Session{})
 | 
			
		||||
	if status == 1 {
 | 
			
		||||
	resp.SUCCESS(c, jobs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 SD 任务列表
 | 
			
		||||
func (h *SdJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
	status := h.GetBool(c, "status")
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	page := h.GetInt(c, "page", 0)
 | 
			
		||||
	pageSize := h.GetInt(c, "page_size", 0)
 | 
			
		||||
	publish := h.GetBool(c, "publish")
 | 
			
		||||
 | 
			
		||||
	err, jobs := h.getData(status, userId, page, pageSize, publish)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, jobs)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// JobList 获取 MJ 任务列表
 | 
			
		||||
func (h *SdJobHandler) getData(finish bool, userId uint, page int, pageSize int, publish bool) (error, []vo.SdJob) {
 | 
			
		||||
 | 
			
		||||
	session := h.DB.Session(&gorm.Session{})
 | 
			
		||||
	if finish {
 | 
			
		||||
		session = session.Where("progress = ?", 100).Order("id DESC")
 | 
			
		||||
	} else {
 | 
			
		||||
		session = session.Where("progress < ?", 100).Order("id ASC")
 | 
			
		||||
@@ -175,6 +243,9 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
	if userId > 0 {
 | 
			
		||||
		session = session.Where("user_id = ?", userId)
 | 
			
		||||
	}
 | 
			
		||||
	if publish {
 | 
			
		||||
		session = session.Where("publish", publish)
 | 
			
		||||
	}
 | 
			
		||||
	if page > 0 && pageSize > 0 {
 | 
			
		||||
		offset := (page - 1) * pageSize
 | 
			
		||||
		session = session.Offset(offset).Limit(pageSize)
 | 
			
		||||
@@ -183,8 +254,7 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
	var items []model.SdJob
 | 
			
		||||
	res := session.Find(&items)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, types.NoData)
 | 
			
		||||
		return
 | 
			
		||||
		return res.Error, nil
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var jobs = make([]vo.SdJob, 0)
 | 
			
		||||
@@ -194,14 +264,70 @@ func (h *SdJobHandler) JobList(c *gin.Context) {
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if item.Progress < 100 {
 | 
			
		||||
			// 30 分钟还没完成的任务直接删除
 | 
			
		||||
			if time.Now().Sub(item.CreatedAt) > time.Minute*30 {
 | 
			
		||||
				h.db.Delete(&item)
 | 
			
		||||
				continue
 | 
			
		||||
			// 从 leveldb 中获取图片预览数据
 | 
			
		||||
			var imageData string
 | 
			
		||||
			err = h.leveldb.Get(item.TaskId, &imageData)
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				job.ImgURL = "data:image/png;base64," + imageData
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		jobs = append(jobs, job)
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c, jobs)
 | 
			
		||||
 | 
			
		||||
	return nil, jobs
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove remove task image
 | 
			
		||||
func (h *SdJobHandler) Remove(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint   `json:"id"`
 | 
			
		||||
		UserId uint   `json:"user_id"`
 | 
			
		||||
		ImgURL string `json:"img_url"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove job recode
 | 
			
		||||
	res := h.DB.Delete(&model.SdJob{Id: data.Id})
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove image
 | 
			
		||||
	err := h.uploader.GetUploadHandler().Delete(data.ImgURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("remove image failed: ", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	client := h.pool.Clients.Get(data.UserId)
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		_ = client.Send([]byte(sd.Finished))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Publish 发布/取消发布图片到画廊显示
 | 
			
		||||
func (h *SdJobHandler) Publish(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Id     uint `json:"id"`
 | 
			
		||||
		Action bool `json:"action"` // 发布动作,true => 发布,false => 取消分享
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res := h.DB.Model(&model.SdJob{Id: data.Id}).UpdateColumn("publish", true)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,36 +1,55 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/sms"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"strings"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
const CodeStorePrefix = "/verify/codes/"
 | 
			
		||||
 | 
			
		||||
type SmsHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	leveldb *store.LevelDB
 | 
			
		||||
	sms     *service.AliYunSmsService
 | 
			
		||||
	redis   *redis.Client
 | 
			
		||||
	sms     *sms.ServiceManager
 | 
			
		||||
	smtp    *service.SmtpService
 | 
			
		||||
	captcha *service.CaptchaService
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewSmsHandler(app *core.AppServer, db *store.LevelDB, sms *service.AliYunSmsService, captcha *service.CaptchaService) *SmsHandler {
 | 
			
		||||
	handler := &SmsHandler{leveldb: db, sms: sms, captcha: captcha}
 | 
			
		||||
	handler.App = app
 | 
			
		||||
	return handler
 | 
			
		||||
func NewSmsHandler(
 | 
			
		||||
	app *core.AppServer,
 | 
			
		||||
	client *redis.Client,
 | 
			
		||||
	sms *sms.ServiceManager,
 | 
			
		||||
	smtp *service.SmtpService,
 | 
			
		||||
	captcha *service.CaptchaService) *SmsHandler {
 | 
			
		||||
	return &SmsHandler{
 | 
			
		||||
		redis:       client,
 | 
			
		||||
		sms:         sms,
 | 
			
		||||
		captcha:     captcha,
 | 
			
		||||
		smtp:        smtp,
 | 
			
		||||
		BaseHandler: BaseHandler{App: app}}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SendCode 发送验证码短信
 | 
			
		||||
// SendCode 发送验证码
 | 
			
		||||
func (h *SmsHandler) SendCode(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Mobile string `json:"mobile"`
 | 
			
		||||
		Key    string `json:"key"`
 | 
			
		||||
		Dots   string `json:"dots"`
 | 
			
		||||
		Receiver string `json:"receiver"` // 接收者
 | 
			
		||||
		Key      string `json:"key"`
 | 
			
		||||
		Dots     string `json:"dots"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
@@ -43,14 +62,28 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	code := utils.RandomNumber(6)
 | 
			
		||||
	err := h.sms.SendVerifyCode(data.Mobile, code)
 | 
			
		||||
	var err error
 | 
			
		||||
	if strings.Contains(data.Receiver, "@") { // email
 | 
			
		||||
		if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "email") {
 | 
			
		||||
			resp.ERROR(c, "系统已禁用邮箱注册!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err = h.smtp.SendVerifyCode(data.Receiver, code)
 | 
			
		||||
	} else {
 | 
			
		||||
		if !utils.ContainsStr(h.App.SysConfig.RegisterWays, "mobile") {
 | 
			
		||||
			resp.ERROR(c, "系统已禁用手机号注册!")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		err = h.sms.GetService().SendVerifyCode(data.Receiver, code)
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 存储验证码,等待后面注册验证
 | 
			
		||||
	err = h.leveldb.Put(CodeStorePrefix+data.Mobile, code)
 | 
			
		||||
	_, err = h.redis.Set(c, CodeStorePrefix+data.Receiver, code, 0).Result()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, "验证码保存失败")
 | 
			
		||||
		return
 | 
			
		||||
@@ -58,13 +91,3 @@ func (h *SmsHandler) SendCode(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type statusVo struct {
 | 
			
		||||
	EnabledMsgService bool `json:"enabled_msg_service"`
 | 
			
		||||
	EnabledRegister   bool `json:"enabled_register"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Status check if the message service is enabled
 | 
			
		||||
func (h *SmsHandler) Status(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c, statusVo{EnabledMsgService: h.App.SysConfig.EnabledMsg, EnabledRegister: h.App.SysConfig.EnabledRegister})
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										17
									
								
								api/handler/test_handler.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								api/handler/test_handler.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,17 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/payment"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type TestHandler struct {
 | 
			
		||||
	db        *gorm.DB
 | 
			
		||||
	snowflake *service.Snowflake
 | 
			
		||||
	js        *payment.PayJS
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewTestHandler(db *gorm.DB, snowflake *service.Snowflake, js *payment.PayJS) *TestHandler {
 | 
			
		||||
	return &TestHandler{db: db, snowflake: snowflake, js: js}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,31 +1,101 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type UploadHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	db              *gorm.DB
 | 
			
		||||
	uploaderManager *oss.UploaderManager
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewUploadHandler(app *core.AppServer, db *gorm.DB, manager *oss.UploaderManager) *UploadHandler {
 | 
			
		||||
	handler := &UploadHandler{db: db, uploaderManager: manager}
 | 
			
		||||
	handler.App = app
 | 
			
		||||
	return handler
 | 
			
		||||
	return &UploadHandler{BaseHandler: BaseHandler{App: app, DB: db}, uploaderManager: manager}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *UploadHandler) Upload(c *gin.Context) {
 | 
			
		||||
	fileURL, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
 | 
			
		||||
	file, err := h.uploaderManager.GetUploadHandler().PutFile(c, "file")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.ERROR(c, err.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, fileURL)
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	res := h.DB.Create(&model.File{
 | 
			
		||||
		UserId:    int(userId),
 | 
			
		||||
		Name:      file.Name,
 | 
			
		||||
		ObjKey:    file.ObjKey,
 | 
			
		||||
		URL:       file.URL,
 | 
			
		||||
		Ext:       file.Ext,
 | 
			
		||||
		Size:      file.Size,
 | 
			
		||||
		CreatedAt: time.Time{},
 | 
			
		||||
	})
 | 
			
		||||
	if res.Error != nil || res.RowsAffected == 0 {
 | 
			
		||||
		resp.ERROR(c, "error with update database: "+res.Error.Error())
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, file)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *UploadHandler) List(c *gin.Context) {
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	var items []model.File
 | 
			
		||||
	var files = make([]vo.File, 0)
 | 
			
		||||
	h.DB.Where("user_id = ?", userId).Find(&items)
 | 
			
		||||
	if len(items) > 0 {
 | 
			
		||||
		for _, v := range items {
 | 
			
		||||
			var file vo.File
 | 
			
		||||
			err := utils.CopyObject(v, &file)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error(err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			file.CreatedAt = v.CreatedAt.Unix()
 | 
			
		||||
			files = append(files, file)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	resp.SUCCESS(c, files)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Remove remove files
 | 
			
		||||
func (h *UploadHandler) Remove(c *gin.Context) {
 | 
			
		||||
	userId := h.GetLoginUserId(c)
 | 
			
		||||
	id := h.GetInt(c, "id", 0)
 | 
			
		||||
	var file model.File
 | 
			
		||||
	tx := h.DB.Where("user_id = ? AND id = ?", userId, id).First(&file)
 | 
			
		||||
	if tx.Error != nil || file.Id == 0 {
 | 
			
		||||
		resp.ERROR(c, "file not existed")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// remove database
 | 
			
		||||
	tx = h.DB.Model(&model.File{}).Delete("id = ?", id)
 | 
			
		||||
	if tx.Error != nil || tx.RowsAffected == 0 {
 | 
			
		||||
		resp.ERROR(c, "failed to update database")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// remove files
 | 
			
		||||
	objectKey := file.ObjKey
 | 
			
		||||
	if objectKey == "" {
 | 
			
		||||
		objectKey = file.URL
 | 
			
		||||
	}
 | 
			
		||||
	_ = h.uploaderManager.GetUploadHandler().Delete(objectKey)
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,19 +1,26 @@
 | 
			
		||||
package handler
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"chatplus/utils/resp"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/store/vo"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"geekai/utils/resp"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"github.com/golang-jwt/jwt/v5"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/lionsoul2014/ip2region/binding/golang/xdb"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
@@ -21,9 +28,7 @@ import (
 | 
			
		||||
 | 
			
		||||
type UserHandler struct {
 | 
			
		||||
	BaseHandler
 | 
			
		||||
	db       *gorm.DB
 | 
			
		||||
	searcher *xdb.Searcher
 | 
			
		||||
	leveldb  *store.LevelDB
 | 
			
		||||
	redis    *redis.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -31,85 +36,117 @@ func NewUserHandler(
 | 
			
		||||
	app *core.AppServer,
 | 
			
		||||
	db *gorm.DB,
 | 
			
		||||
	searcher *xdb.Searcher,
 | 
			
		||||
	levelDB *store.LevelDB,
 | 
			
		||||
	client *redis.Client) *UserHandler {
 | 
			
		||||
	handler := &UserHandler{db: db, searcher: searcher, leveldb: levelDB, redis: client}
 | 
			
		||||
	handler.App = app
 | 
			
		||||
	return handler
 | 
			
		||||
	return &UserHandler{
 | 
			
		||||
		BaseHandler: BaseHandler{DB: db, App: app},
 | 
			
		||||
		searcher:    searcher,
 | 
			
		||||
		redis:       client,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Register user register
 | 
			
		||||
func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
	// parameters process
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Mobile   string `json:"mobile"`
 | 
			
		||||
		Password string `json:"password"`
 | 
			
		||||
		Code     int    `json:"code"`
 | 
			
		||||
		RegWay     string `json:"reg_way"`
 | 
			
		||||
		Username   string `json:"username"`
 | 
			
		||||
		Password   string `json:"password"`
 | 
			
		||||
		Code       string `json:"code"`
 | 
			
		||||
		InviteCode string `json:"invite_code"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	data.Password = strings.TrimSpace(data.Password)
 | 
			
		||||
 | 
			
		||||
	if len(data.Mobile) < 10 {
 | 
			
		||||
		resp.ERROR(c, "请输入合法的手机号")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	if len(data.Password) < 8 {
 | 
			
		||||
		resp.ERROR(c, "密码长度不能少于8个字符")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查验证码
 | 
			
		||||
	key := CodeStorePrefix + data.Mobile
 | 
			
		||||
	if h.App.SysConfig.EnabledMsg {
 | 
			
		||||
		var code int
 | 
			
		||||
		err := h.leveldb.Get(key, &code)
 | 
			
		||||
	var key string
 | 
			
		||||
	if data.RegWay == "email" || data.RegWay == "mobile" {
 | 
			
		||||
		key = CodeStorePrefix + data.Username
 | 
			
		||||
		code, err := h.redis.Get(c, key).Result()
 | 
			
		||||
		if err != nil || code != data.Code {
 | 
			
		||||
			resp.ERROR(c, "短信验证码错误")
 | 
			
		||||
			resp.ERROR(c, "验证码错误")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 验证邀请码
 | 
			
		||||
	inviteCode := model.InviteCode{}
 | 
			
		||||
	if data.InviteCode != "" {
 | 
			
		||||
		res := h.DB.Where("code = ?", data.InviteCode).First(&inviteCode)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			resp.ERROR(c, "无效的邀请码")
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// check if the username is exists
 | 
			
		||||
	var item model.User
 | 
			
		||||
	res := h.db.Where("mobile = ?", data.Mobile).First(&item)
 | 
			
		||||
	if res.RowsAffected > 0 {
 | 
			
		||||
		resp.ERROR(c, "该手机号码已经被注册,请更换其他手机号")
 | 
			
		||||
	res := h.DB.Where("username = ?", data.Username).First(&item)
 | 
			
		||||
	if item.Id > 0 {
 | 
			
		||||
		resp.ERROR(c, "该用户名已经被注册")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	salt := utils.RandString(8)
 | 
			
		||||
	user := model.User{
 | 
			
		||||
		Username:   data.Username,
 | 
			
		||||
		Password:   utils.GenPassword(data.Password, salt),
 | 
			
		||||
		Nickname:   fmt.Sprintf("极客学长@%d", utils.RandomNumber(6)),
 | 
			
		||||
		Avatar:     "/images/avatar/user.png",
 | 
			
		||||
		Salt:       salt,
 | 
			
		||||
		Status:     true,
 | 
			
		||||
		Mobile:     data.Mobile,
 | 
			
		||||
		ChatRoles:  utils.JsonEncode([]string{"gpt"}),               // 默认只订阅通用助手角色
 | 
			
		||||
		ChatModels: utils.JsonEncode(h.App.SysConfig.DefaultModels), // 默认开通的模型
 | 
			
		||||
		ChatConfig: utils.JsonEncode(types.UserChatConfig{
 | 
			
		||||
			ApiKeys: map[types.Platform]string{
 | 
			
		||||
				types.OpenAI:  "",
 | 
			
		||||
				types.Azure:   "",
 | 
			
		||||
				types.ChatGLM: "",
 | 
			
		||||
			},
 | 
			
		||||
		}),
 | 
			
		||||
		Calls:    h.App.SysConfig.UserInitCalls,
 | 
			
		||||
		ImgCalls: h.App.SysConfig.InitImgCalls,
 | 
			
		||||
		Power:      h.App.SysConfig.InitPower,
 | 
			
		||||
	}
 | 
			
		||||
	res = h.db.Create(&user)
 | 
			
		||||
 | 
			
		||||
	res = h.DB.Create(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "保存数据失败")
 | 
			
		||||
		logger.Error(res.Error)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if h.App.SysConfig.EnabledMsg {
 | 
			
		||||
		_ = h.leveldb.Delete(key) // 注册成功,删除短信验证码
 | 
			
		||||
	// 记录邀请关系
 | 
			
		||||
	if data.InviteCode != "" {
 | 
			
		||||
		// 增加邀请数量
 | 
			
		||||
		h.DB.Model(&model.InviteCode{}).Where("code = ?", data.InviteCode).UpdateColumn("reg_num", gorm.Expr("reg_num + ?", 1))
 | 
			
		||||
		if h.App.SysConfig.InvitePower > 0 {
 | 
			
		||||
			h.DB.Model(&model.User{}).Where("id = ?", inviteCode.UserId).UpdateColumn("power", gorm.Expr("power + ?", h.App.SysConfig.InvitePower))
 | 
			
		||||
			// 记录邀请算力充值日志
 | 
			
		||||
			var inviter model.User
 | 
			
		||||
			h.DB.Where("id", inviteCode.UserId).First(&inviter)
 | 
			
		||||
			h.DB.Create(&model.PowerLog{
 | 
			
		||||
				UserId:    inviter.Id,
 | 
			
		||||
				Username:  inviter.Username,
 | 
			
		||||
				Type:      types.PowerInvite,
 | 
			
		||||
				Amount:    h.App.SysConfig.InvitePower,
 | 
			
		||||
				Balance:   inviter.Power,
 | 
			
		||||
				Mark:      types.PowerAdd,
 | 
			
		||||
				Model:     "",
 | 
			
		||||
				Remark:    fmt.Sprintf("邀请用户注册奖励,金额:%d,邀请码:%s,新用户:%s", h.App.SysConfig.InvitePower, inviteCode.Code, user.Username),
 | 
			
		||||
				CreatedAt: time.Now(),
 | 
			
		||||
			})
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 添加邀请记录
 | 
			
		||||
		h.DB.Create(&model.InviteLog{
 | 
			
		||||
			InviterId:  inviteCode.UserId,
 | 
			
		||||
			UserId:     user.Id,
 | 
			
		||||
			Username:   user.Username,
 | 
			
		||||
			InviteCode: inviteCode.Code,
 | 
			
		||||
			Remark:     fmt.Sprintf("奖励 %d 算力", h.App.SysConfig.InvitePower),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_ = h.redis.Del(c, key) // 注册成功,删除短信验证码
 | 
			
		||||
 | 
			
		||||
	// 自动登录创建 token
 | 
			
		||||
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{
 | 
			
		||||
		"user_id": user.Id,
 | 
			
		||||
@@ -132,7 +169,7 @@ func (h *UserHandler) Register(c *gin.Context) {
 | 
			
		||||
// Login 用户登录
 | 
			
		||||
func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Mobile   string `json:"username"`
 | 
			
		||||
		Username string `json:"username"`
 | 
			
		||||
		Password string `json:"password"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
@@ -140,7 +177,7 @@ func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.db.Where("mobile = ?", data.Mobile).First(&user)
 | 
			
		||||
	res := h.DB.Where("username = ?", data.Username).First(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "用户名不存在")
 | 
			
		||||
		return
 | 
			
		||||
@@ -160,11 +197,11 @@ func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
	// 更新最后登录时间和IP
 | 
			
		||||
	user.LastLoginIp = c.ClientIP()
 | 
			
		||||
	user.LastLoginAt = time.Now().Unix()
 | 
			
		||||
	h.db.Model(&user).Updates(user)
 | 
			
		||||
	h.DB.Model(&user).Updates(user)
 | 
			
		||||
 | 
			
		||||
	h.db.Create(&model.UserLoginLog{
 | 
			
		||||
	h.DB.Create(&model.UserLoginLog{
 | 
			
		||||
		UserId:       user.Id,
 | 
			
		||||
		Username:     user.Mobile,
 | 
			
		||||
		Username:     user.Username,
 | 
			
		||||
		LoginIp:      c.ClientIP(),
 | 
			
		||||
		LoginAddress: utils.Ip2Region(h.searcher, c.ClientIP()),
 | 
			
		||||
	})
 | 
			
		||||
@@ -190,24 +227,16 @@ func (h *UserHandler) Login(c *gin.Context) {
 | 
			
		||||
 | 
			
		||||
// Logout 注 销
 | 
			
		||||
func (h *UserHandler) Logout(c *gin.Context) {
 | 
			
		||||
	sessionId := c.GetHeader(types.ChatTokenHeader)
 | 
			
		||||
	key := h.GetUserKey(c)
 | 
			
		||||
	if _, err := h.redis.Del(c, key).Result(); err != nil {
 | 
			
		||||
		logger.Error("error with delete session: ", err)
 | 
			
		||||
	}
 | 
			
		||||
	// 删除 websocket 会话列表
 | 
			
		||||
	h.App.ChatSession.Delete(sessionId)
 | 
			
		||||
	// 关闭 socket 连接
 | 
			
		||||
	client := h.App.ChatClients.Get(sessionId)
 | 
			
		||||
	if client != nil {
 | 
			
		||||
		client.Close()
 | 
			
		||||
	}
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Session 获取/验证会话
 | 
			
		||||
func (h *UserHandler) Session(c *gin.Context) {
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err == nil {
 | 
			
		||||
		var userVo vo.User
 | 
			
		||||
		err := utils.CopyObject(user, &userVo)
 | 
			
		||||
@@ -223,26 +252,23 @@ func (h *UserHandler) Session(c *gin.Context) {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type userProfile struct {
 | 
			
		||||
	Id          uint                 `json:"id"`
 | 
			
		||||
	Mobile      string               `json:"mobile"`
 | 
			
		||||
	Avatar      string               `json:"avatar"`
 | 
			
		||||
	ChatConfig  types.UserChatConfig `json:"chat_config"`
 | 
			
		||||
	Calls       int                  `json:"calls"`
 | 
			
		||||
	ImgCalls    int                  `json:"img_calls"`
 | 
			
		||||
	TotalTokens int64                `json:"total_tokens"`
 | 
			
		||||
	Tokens      int64                `json:"tokens"`
 | 
			
		||||
	ExpiredTime int64                `json:"expired_time"`
 | 
			
		||||
	Vip         bool                 `json:"vip"`
 | 
			
		||||
	Id          uint   `json:"id"`
 | 
			
		||||
	Nickname    string `json:"nickname"`
 | 
			
		||||
	Username    string `json:"username"`
 | 
			
		||||
	Avatar      string `json:"avatar"`
 | 
			
		||||
	Power       int    `json:"power"`
 | 
			
		||||
	ExpiredTime int64  `json:"expired_time"`
 | 
			
		||||
	Vip         bool   `json:"vip"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (h *UserHandler) Profile(c *gin.Context) {
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	h.db.First(&user, user.Id)
 | 
			
		||||
	h.DB.First(&user, user.Id)
 | 
			
		||||
	var profile userProfile
 | 
			
		||||
	err = utils.CopyObject(user, &profile)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
@@ -262,15 +288,15 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	h.db.First(&user, user.Id)
 | 
			
		||||
	h.DB.First(&user, user.Id)
 | 
			
		||||
	user.Avatar = data.Avatar
 | 
			
		||||
	user.ChatConfig = utils.JsonEncode(data.ChatConfig)
 | 
			
		||||
	res := h.db.Updates(&user)
 | 
			
		||||
	user.Nickname = data.Nickname
 | 
			
		||||
	res := h.DB.Updates(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新用户信息失败")
 | 
			
		||||
		return
 | 
			
		||||
@@ -279,8 +305,8 @@ func (h *UserHandler) ProfileUpdate(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Password 更新密码
 | 
			
		||||
func (h *UserHandler) Password(c *gin.Context) {
 | 
			
		||||
// UpdatePass 更新密码
 | 
			
		||||
func (h *UserHandler) UpdatePass(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		OldPass  string `json:"old_pass"`
 | 
			
		||||
		Password string `json:"password"`
 | 
			
		||||
@@ -295,21 +321,21 @@ func (h *UserHandler) Password(c *gin.Context) {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	password := utils.GenPassword(data.OldPass, user.Salt)
 | 
			
		||||
	logger.Info(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
 | 
			
		||||
	logger.Debugf(user.Salt, ",", user.Password, ",", password, ",", data.OldPass)
 | 
			
		||||
	if password != user.Password {
 | 
			
		||||
		resp.ERROR(c, "原密码错误")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	newPass := utils.GenPassword(data.Password, user.Salt)
 | 
			
		||||
	res := h.db.Model(&user).UpdateColumn("password", newPass)
 | 
			
		||||
	res := h.DB.Model(&user).UpdateColumn("password", newPass)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		logger.Error("更新数据库失败: ", res.Error)
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
@@ -319,46 +345,83 @@ func (h *UserHandler) Password(c *gin.Context) {
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BindMobile 绑定手机号
 | 
			
		||||
func (h *UserHandler) BindMobile(c *gin.Context) {
 | 
			
		||||
// ResetPass 重置密码
 | 
			
		||||
func (h *UserHandler) ResetPass(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Mobile string `json:"mobile"`
 | 
			
		||||
		Code   int    `json:"code"`
 | 
			
		||||
		Username string `json:"username"`
 | 
			
		||||
		Code     string `json:"code"`     // 验证码
 | 
			
		||||
		Password string `json:"password"` // 新密码
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查手机号是否被其他账号绑定
 | 
			
		||||
	var item model.User
 | 
			
		||||
	res := h.db.Where("mobile = ?", data.Mobile).First(&item)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		resp.ERROR(c, "该手机号已经被其他账号绑定")
 | 
			
		||||
	var user model.User
 | 
			
		||||
	res := h.DB.Where("username", data.Username).First(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "用户不存在!")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查验证码
 | 
			
		||||
	key := CodeStorePrefix + data.Mobile
 | 
			
		||||
	var code int
 | 
			
		||||
	err := h.leveldb.Get(key, &code)
 | 
			
		||||
	key := CodeStorePrefix + data.Username
 | 
			
		||||
	code, err := h.redis.Get(c, key).Result()
 | 
			
		||||
	if err != nil || code != data.Code {
 | 
			
		||||
		resp.ERROR(c, "短信验证码错误")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user, err := utils.GetLoginUser(c, h.db)
 | 
			
		||||
	password := utils.GenPassword(data.Password, user.Salt)
 | 
			
		||||
	user.Password = password
 | 
			
		||||
	res = h.DB.Updates(&user)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c)
 | 
			
		||||
	} else {
 | 
			
		||||
		h.redis.Del(c, key)
 | 
			
		||||
		resp.SUCCESS(c)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// BindUsername 重置账号
 | 
			
		||||
func (h *UserHandler) BindUsername(c *gin.Context) {
 | 
			
		||||
	var data struct {
 | 
			
		||||
		Username string `json:"username"`
 | 
			
		||||
		Code     string `json:"code"`
 | 
			
		||||
	}
 | 
			
		||||
	if err := c.ShouldBindJSON(&data); err != nil {
 | 
			
		||||
		resp.ERROR(c, types.InvalidArgs)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查验证码
 | 
			
		||||
	key := CodeStorePrefix + data.Username
 | 
			
		||||
	code, err := h.redis.Get(c, key).Result()
 | 
			
		||||
	if err != nil || code != data.Code {
 | 
			
		||||
		resp.ERROR(c, "验证码错误")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 检查手机号是否被其他账号绑定
 | 
			
		||||
	var item model.User
 | 
			
		||||
	res := h.DB.Where("username = ?", data.Username).First(&item)
 | 
			
		||||
	if res.Error == nil {
 | 
			
		||||
		resp.ERROR(c, "该账号已经被其他账号绑定")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	user, err := h.GetLoginUser(c)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		resp.NotAuth(c)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	res = h.db.Model(&user).UpdateColumn("mobile", data.Mobile)
 | 
			
		||||
	res = h.DB.Model(&user).UpdateColumn("username", data.Username)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		resp.ERROR(c, "更新数据库失败")
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	_ = h.leveldb.Delete(key) // 删除短信验证码
 | 
			
		||||
	_ = h.redis.Del(c, key) // 删除短信验证码
 | 
			
		||||
	resp.SUCCESS(c)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,5 +1,12 @@
 | 
			
		||||
package logger
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"go.uber.org/zap"
 | 
			
		||||
	"go.uber.org/zap/zapcore"
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										269
									
								
								api/main.go
									
									
									
									
									
								
							
							
						
						
									
										269
									
								
								api/main.go
									
									
									
									
									
								
							@@ -1,23 +1,30 @@
 | 
			
		||||
package main
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/handler"
 | 
			
		||||
	"chatplus/handler/admin"
 | 
			
		||||
	"chatplus/handler/chatimpl"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/service"
 | 
			
		||||
	"chatplus/service/fun"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/service/payment"
 | 
			
		||||
	"chatplus/service/sd"
 | 
			
		||||
	"chatplus/service/wx"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"context"
 | 
			
		||||
	"embed"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"geekai/core"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/handler"
 | 
			
		||||
	"geekai/handler/admin"
 | 
			
		||||
	"geekai/handler/chatimpl"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/dalle"
 | 
			
		||||
	"geekai/service/mj"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/service/payment"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/service/sms"
 | 
			
		||||
	"geekai/service/wx"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"io"
 | 
			
		||||
	"log"
 | 
			
		||||
	"os"
 | 
			
		||||
@@ -26,6 +33,8 @@ import (
 | 
			
		||||
	"syscall"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
 | 
			
		||||
	"github.com/lionsoul2014/ip2region/binding/golang/xdb"
 | 
			
		||||
	"go.uber.org/fx"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
@@ -42,34 +51,34 @@ type AppLifecycle struct {
 | 
			
		||||
 | 
			
		||||
// OnStart 应用程序启动时执行
 | 
			
		||||
func (l *AppLifecycle) OnStart(context.Context) error {
 | 
			
		||||
	log.Println("AppLifecycle OnStart")
 | 
			
		||||
	logger.Info("AppLifecycle OnStart")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// OnStop 应用程序停止时执行
 | 
			
		||||
func (l *AppLifecycle) OnStop(context.Context) error {
 | 
			
		||||
	log.Println("AppLifecycle OnStop")
 | 
			
		||||
	logger.Info("AppLifecycle OnStop")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewAppLifeCycle() *AppLifecycle {
 | 
			
		||||
	return &AppLifecycle{}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func main() {
 | 
			
		||||
	configFile := os.Getenv("CONFIG_FILE")
 | 
			
		||||
	if configFile == "" {
 | 
			
		||||
		configFile = "config.toml"
 | 
			
		||||
	}
 | 
			
		||||
	var debug bool
 | 
			
		||||
	debugEnv := os.Getenv("DEBUG")
 | 
			
		||||
	if debugEnv == "" {
 | 
			
		||||
		debug = true
 | 
			
		||||
	} else {
 | 
			
		||||
		debug, _ = strconv.ParseBool(os.Getenv("DEBUG"))
 | 
			
		||||
	}
 | 
			
		||||
	debug, _ := strconv.ParseBool(os.Getenv("APP_DEBUG"))
 | 
			
		||||
	logger.Info("Loading config file: ", configFile)
 | 
			
		||||
	defer func() {
 | 
			
		||||
		if err := recover(); err != nil {
 | 
			
		||||
			logger.Error("Panic Error:", err)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
	if !debug {
 | 
			
		||||
		defer func() {
 | 
			
		||||
			if err := recover(); err != nil {
 | 
			
		||||
				logger.Error("Panic Error:", err)
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	app := fx.New(
 | 
			
		||||
		// 初始化配置应用配置
 | 
			
		||||
@@ -94,8 +103,8 @@ func main() {
 | 
			
		||||
		// 初始化数据库
 | 
			
		||||
		fx.Provide(store.NewGormConfig),
 | 
			
		||||
		fx.Provide(store.NewMysql),
 | 
			
		||||
		fx.Provide(store.NewLevelDB),
 | 
			
		||||
		fx.Provide(store.NewRedisClient),
 | 
			
		||||
		fx.Provide(store.NewLevelDB),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(func() embed.FS {
 | 
			
		||||
			return xdbFS
 | 
			
		||||
@@ -115,9 +124,6 @@ func main() {
 | 
			
		||||
			return xdb.NewWithBuffer(cBuff)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 创建函数
 | 
			
		||||
		fx.Provide(fun.NewFunctions),
 | 
			
		||||
 | 
			
		||||
		// 创建控制器
 | 
			
		||||
		fx.Provide(handler.NewChatRoleHandler),
 | 
			
		||||
		fx.Provide(handler.NewUserHandler),
 | 
			
		||||
@@ -132,6 +138,8 @@ func main() {
 | 
			
		||||
		fx.Provide(handler.NewPaymentHandler),
 | 
			
		||||
		fx.Provide(handler.NewOrderHandler),
 | 
			
		||||
		fx.Provide(handler.NewProductHandler),
 | 
			
		||||
		fx.Provide(handler.NewConfigHandler),
 | 
			
		||||
		fx.Provide(handler.NewPowerLogHandler),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(admin.NewConfigHandler),
 | 
			
		||||
		fx.Provide(admin.NewAdminHandler),
 | 
			
		||||
@@ -143,14 +151,26 @@ func main() {
 | 
			
		||||
		fx.Provide(admin.NewChatModelHandler),
 | 
			
		||||
		fx.Provide(admin.NewProductHandler),
 | 
			
		||||
		fx.Provide(admin.NewOrderHandler),
 | 
			
		||||
		fx.Provide(admin.NewChatHandler),
 | 
			
		||||
		fx.Provide(admin.NewPowerLogHandler),
 | 
			
		||||
 | 
			
		||||
		// 创建服务
 | 
			
		||||
		fx.Provide(service.NewAliYunSmsService),
 | 
			
		||||
		fx.Provide(sms.NewSendServiceManager),
 | 
			
		||||
		fx.Provide(func(config *types.AppConfig) *service.CaptchaService {
 | 
			
		||||
			return service.NewCaptchaService(config.ApiConfig)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(oss.NewUploaderManager),
 | 
			
		||||
		fx.Provide(mj.NewService),
 | 
			
		||||
		fx.Provide(dalle.NewService),
 | 
			
		||||
		fx.Invoke(func(service *dalle.Service) {
 | 
			
		||||
			service.Run()
 | 
			
		||||
			service.CheckTaskNotify()
 | 
			
		||||
			service.DownloadImages()
 | 
			
		||||
			service.CheckTaskStatus()
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 邮件服务
 | 
			
		||||
		fx.Provide(service.NewSmtpService),
 | 
			
		||||
 | 
			
		||||
		// 微信机器人服务
 | 
			
		||||
		fx.Provide(wx.NewWeChatBot),
 | 
			
		||||
@@ -163,36 +183,30 @@ func main() {
 | 
			
		||||
			}
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// MidJourney 机器人
 | 
			
		||||
		fx.Provide(mj.NewBot),
 | 
			
		||||
		fx.Provide(mj.NewClient),
 | 
			
		||||
		fx.Invoke(func(config *types.AppConfig, bot *mj.Bot) {
 | 
			
		||||
			if config.MjConfig.Enabled {
 | 
			
		||||
				err := bot.Run()
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Fatal("MidJourney 服务启动失败:", err)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(config *types.AppConfig, mjService *mj.Service) {
 | 
			
		||||
			if config.MjConfig.Enabled {
 | 
			
		||||
				go func() {
 | 
			
		||||
					mjService.Run()
 | 
			
		||||
				}()
 | 
			
		||||
		// MidJourney service pool
 | 
			
		||||
		fx.Provide(mj.NewServicePool),
 | 
			
		||||
		fx.Invoke(func(pool *mj.ServicePool, config *types.AppConfig) {
 | 
			
		||||
			pool.InitServices(config.MjPlusConfigs, config.MjProxyConfigs)
 | 
			
		||||
			if pool.HasAvailableService() {
 | 
			
		||||
				pool.DownloadImages()
 | 
			
		||||
				pool.CheckTaskNotify()
 | 
			
		||||
				pool.SyncTaskProgress()
 | 
			
		||||
			}
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// Stable Diffusion 机器人
 | 
			
		||||
		fx.Provide(sd.NewService),
 | 
			
		||||
		fx.Invoke(func(config *types.AppConfig, service *sd.Service) {
 | 
			
		||||
			if config.SdConfig.Enabled {
 | 
			
		||||
				go func() {
 | 
			
		||||
					service.Run()
 | 
			
		||||
				}()
 | 
			
		||||
		fx.Provide(sd.NewServicePool),
 | 
			
		||||
		fx.Invoke(func(pool *sd.ServicePool, config *types.AppConfig) {
 | 
			
		||||
			pool.InitServices(config.SdConfigs)
 | 
			
		||||
			if pool.HasAvailableService() {
 | 
			
		||||
				pool.CheckTaskNotify()
 | 
			
		||||
				pool.CheckTaskStatus()
 | 
			
		||||
			}
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(payment.NewAlipayService),
 | 
			
		||||
		fx.Provide(payment.NewHuPiPay),
 | 
			
		||||
		fx.Provide(payment.NewPayJS),
 | 
			
		||||
		fx.Provide(service.NewSnowflake),
 | 
			
		||||
		fx.Provide(service.NewXXLJobExecutor),
 | 
			
		||||
		fx.Invoke(func(exec *service.XXLJobExecutor, config *types.AppConfig) {
 | 
			
		||||
@@ -217,8 +231,9 @@ func main() {
 | 
			
		||||
			group.GET("session", h.Session)
 | 
			
		||||
			group.GET("profile", h.Profile)
 | 
			
		||||
			group.POST("profile/update", h.ProfileUpdate)
 | 
			
		||||
			group.POST("password", h.Password)
 | 
			
		||||
			group.POST("bind/mobile", h.BindMobile)
 | 
			
		||||
			group.POST("password", h.UpdatePass)
 | 
			
		||||
			group.POST("bind/username", h.BindUsername)
 | 
			
		||||
			group.POST("resetPass", h.ResetPass)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *chatimpl.ChatHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/chat/")
 | 
			
		||||
@@ -234,16 +249,19 @@ func main() {
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.UploadHandler) {
 | 
			
		||||
			s.Engine.POST("/api/upload", h.Upload)
 | 
			
		||||
			s.Engine.GET("/api/upload/list", h.List)
 | 
			
		||||
			s.Engine.GET("/api/upload/remove", h.Remove)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.SmsHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/sms/")
 | 
			
		||||
			group.GET("status", h.Status)
 | 
			
		||||
			group.POST("code", h.SendCode)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.CaptchaHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/captcha/")
 | 
			
		||||
			group.GET("get", h.Get)
 | 
			
		||||
			group.POST("check", h.Check)
 | 
			
		||||
			group.GET("slide/get", h.SlideGet)
 | 
			
		||||
			group.POST("slide/check", h.SlideCheck)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.RewardHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/reward/")
 | 
			
		||||
@@ -251,36 +269,51 @@ func main() {
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.MidJourneyHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/mj/")
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
			group.POST("image", h.Image)
 | 
			
		||||
			group.POST("upscale", h.Upscale)
 | 
			
		||||
			group.POST("variation", h.Variation)
 | 
			
		||||
			group.GET("jobs", h.JobList)
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
			group.GET("imgWall", h.ImgWall)
 | 
			
		||||
			group.POST("remove", h.Remove)
 | 
			
		||||
			group.POST("publish", h.Publish)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.SdJobHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/sd")
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
			group.POST("image", h.Image)
 | 
			
		||||
			group.GET("jobs", h.JobList)
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
			group.GET("imgWall", h.ImgWall)
 | 
			
		||||
			group.POST("remove", h.Remove)
 | 
			
		||||
			group.POST("publish", h.Publish)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.ConfigHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/config/")
 | 
			
		||||
			group.GET("get", h.Get)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 管理后台控制器
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.ConfigHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/config/")
 | 
			
		||||
			group.POST("update", h.Update)
 | 
			
		||||
			group.GET("get", h.Get)
 | 
			
		||||
			group := s.Engine.Group("/api/admin/")
 | 
			
		||||
			group.POST("config/update", h.Update)
 | 
			
		||||
			group.GET("config/get", h.Get)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.ManagerHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/")
 | 
			
		||||
			group.POST("login", h.Login)
 | 
			
		||||
			group.GET("logout", h.Logout)
 | 
			
		||||
			group.GET("session", h.Session)
 | 
			
		||||
			group.GET("migrate", h.Migrate)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.POST("save", h.Save)
 | 
			
		||||
			group.POST("enable", h.Enable)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
			group.POST("resetPass", h.ResetPass)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.ApiKeyHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/apikey/")
 | 
			
		||||
			group.POST("save", h.Save)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.POST("set", h.Set)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.UserHandler) {
 | 
			
		||||
@@ -296,11 +329,13 @@ func main() {
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.POST("save", h.Save)
 | 
			
		||||
			group.POST("sort", h.Sort)
 | 
			
		||||
			group.POST("set", h.Set)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.RewardHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/reward/")
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.POST("remove", h.Remove)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.DashboardHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/dashboard/")
 | 
			
		||||
@@ -314,16 +349,20 @@ func main() {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/model/")
 | 
			
		||||
			group.POST("save", h.Save)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.POST("enable", h.Enable)
 | 
			
		||||
			group.POST("set", h.Set)
 | 
			
		||||
			group.POST("sort", h.Sort)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.PaymentHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/payment/")
 | 
			
		||||
			group.GET("alipay", h.Alipay)
 | 
			
		||||
			group.GET("doPay", h.DoPay)
 | 
			
		||||
			group.GET("payWays", h.GetPayWays)
 | 
			
		||||
			group.POST("query", h.OrderQuery)
 | 
			
		||||
			group.POST("alipay/qrcode", h.AlipayQrcode)
 | 
			
		||||
			group.POST("qrcode", h.PayQrcode)
 | 
			
		||||
			group.POST("mobile", h.Mobile)
 | 
			
		||||
			group.POST("alipay/notify", h.AlipayNotify)
 | 
			
		||||
			group.POST("hupipay/notify", h.HuPiPayNotify)
 | 
			
		||||
			group.POST("payjs/notify", h.PayJsNotify)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.ProductHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/product/")
 | 
			
		||||
@@ -347,13 +386,97 @@ func main() {
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
 | 
			
		||||
			err := s.Run(db)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				log.Fatal(err)
 | 
			
		||||
			}
 | 
			
		||||
		fx.Provide(handler.NewInviteHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.InviteHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/invite/")
 | 
			
		||||
			group.GET("code", h.Code)
 | 
			
		||||
			group.POST("list", h.List)
 | 
			
		||||
			group.GET("hits", h.Hits)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(admin.NewFunctionHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.FunctionHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/function/")
 | 
			
		||||
			group.POST("save", h.Save)
 | 
			
		||||
			group.POST("set", h.Set)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
			group.GET("token", h.GenToken)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		// 验证码
 | 
			
		||||
		fx.Provide(admin.NewCaptchaHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.CaptchaHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/login/")
 | 
			
		||||
			group.GET("captcha", h.GetCaptcha)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(admin.NewUploadHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.UploadHandler) {
 | 
			
		||||
			s.Engine.POST("/api/admin/upload", h.Upload)
 | 
			
		||||
		}),
 | 
			
		||||
 | 
			
		||||
		fx.Provide(handler.NewFunctionHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.FunctionHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/function/")
 | 
			
		||||
			group.POST("weibo", h.WeiBo)
 | 
			
		||||
			group.POST("zaobao", h.ZaoBao)
 | 
			
		||||
			group.POST("dalle3", h.Dall3)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.ChatHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/chat/")
 | 
			
		||||
			group.POST("list", h.List)
 | 
			
		||||
			group.POST("message", h.Messages)
 | 
			
		||||
			group.GET("history", h.History)
 | 
			
		||||
			group.GET("remove", h.RemoveChat)
 | 
			
		||||
			group.GET("message/remove", h.RemoveMessage)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.PowerLogHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/powerLog/")
 | 
			
		||||
			group.POST("list", h.List)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.PowerLogHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/powerLog/")
 | 
			
		||||
			group.POST("list", h.List)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(admin.NewMenuHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *admin.MenuHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/admin/menu/")
 | 
			
		||||
			group.POST("save", h.Save)
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
			group.POST("enable", h.Enable)
 | 
			
		||||
			group.POST("sort", h.Sort)
 | 
			
		||||
			group.GET("remove", h.Remove)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(handler.NewMenuHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.MenuHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/menu/")
 | 
			
		||||
			group.GET("list", h.List)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(handler.NewMarkMapHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.MarkMapHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/markMap/")
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(handler.NewDallJobHandler),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, h *handler.DallJobHandler) {
 | 
			
		||||
			group := s.Engine.Group("/api/dall")
 | 
			
		||||
			group.Any("client", h.Client)
 | 
			
		||||
			group.POST("image", h.Image)
 | 
			
		||||
			group.GET("jobs", h.JobList)
 | 
			
		||||
			group.GET("imgWall", h.ImgWall)
 | 
			
		||||
			group.POST("remove", h.Remove)
 | 
			
		||||
			group.POST("publish", h.Publish)
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Invoke(func(s *core.AppServer, db *gorm.DB) {
 | 
			
		||||
			go func() {
 | 
			
		||||
				err := s.Run(db)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					log.Fatal(err)
 | 
			
		||||
				}
 | 
			
		||||
			}()
 | 
			
		||||
		}),
 | 
			
		||||
		fx.Provide(NewAppLifeCycle),
 | 
			
		||||
		// 注册生命周期回调函数
 | 
			
		||||
		fx.Invoke(func(lifecycle fx.Lifecycle, lc *AppLifecycle) {
 | 
			
		||||
			lifecycle.Append(fx.Hook{
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										38
									
								
								api/res/certs/alipay/alipayPublicCert.crt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										38
									
								
								api/res/certs/alipay/alipayPublicCert.crt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,38 @@
 | 
			
		||||
-----BEGIN CERTIFICATE-----
 | 
			
		||||
MIIDszCCApugAwIBAgIQICMRB0rBU2/rZJbfJGMYIzANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE
 | 
			
		||||
BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv
 | 
			
		||||
biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1
 | 
			
		||||
dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDYzNTQxWhcNMjQxMTA2MDYzNTQxWjCB
 | 
			
		||||
hDELMAkGA1UEBhMCQ04xHzAdBgNVBAoMFm1ib25meTkwMTVAc2FuZGJveC5jb20xDzANBgNVBAsM
 | 
			
		||||
BkFsaXBheTFDMEEGA1UEAww65pSv5LuY5a6dKOS4reWbvSnnvZHnu5zmioDmnK/mnInpmZDlhazl
 | 
			
		||||
j7gtMjA4ODcyMTAyMDc1MDU4MTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKsoKcw5
 | 
			
		||||
sxaiyV7mpWzDtnQ1K518eQLP0+dJlZAf06aBep/Aj9DIqrba/k7DHt8dKQvILMLAMpN1+2IRxbaO
 | 
			
		||||
yxMa/laj3lZ1eHrB6F077O3D62oHcE3noZtXL0N1zZAxpmkNmYIHeLZS2oLMS4ANu47O/wpDC7BV
 | 
			
		||||
HjdpZugtdPJ4mxdCpM9GDdLs7W4s5QI4PUPK4skFNMFoKI+0cYP/9ju87UP//IHC/K510GWNl+Gn
 | 
			
		||||
Cvgag3AmiIB0utJNsGhxm6zT1T9tUWjW9iz/BxBKiPatsCX9VpPQzGnW7ZonRQtiZSokIlP2IPvl
 | 
			
		||||
H5DcwpWUz3/LUY0SmKxnKOEYeOOqCW8CAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3
 | 
			
		||||
DQEBCwUAA4IBAQAtgxF2EzjOndEFxBUD9tFwcSt6XKGggOp52oft1pvynPg4ALTLafOtfEPDrFBH
 | 
			
		||||
PwpYrSu9s9C8NJtaA2HrlCfBjIuwEFTXiN+HPvS0SwSPKt9AXEiTcOF8vDcGamEen8QI4fo5Jia7
 | 
			
		||||
2VRKkerkww5/+FzSaVO7ZUKuL80M1QJStmAZc8kPPwdYOTTW2bGf8BcmSDL6SPElBkt7tCCRd4sn
 | 
			
		||||
+jq4cZ0yb2i77rBZCwHcTvfTqIBblPwLv4uGvg3+83BxIB5w6Kqp06bKEAPmobFY5IVHa+ON0/qi
 | 
			
		||||
BXxXr+WQ3piKRVQEN64+PTAjSc67Ix1umvpLl3Ko6Ry7NJmpDcUn
 | 
			
		||||
-----END CERTIFICATE-----
 | 
			
		||||
-----BEGIN CERTIFICATE-----
 | 
			
		||||
MIIDszCCApugAwIBAgIQIBkIGbgVxq210KxLJ+YA/TANBgkqhkiG9w0BAQsFADCBhDELMAkGA1UE
 | 
			
		||||
BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxJTAjBgNVBAsMHENlcnRpZmljYXRpb24gQXV0
 | 
			
		||||
aG9yaXR5IHRlc3QxNjA0BgNVBAMMLUFudCBGaW5hbmNpYWwgQ2VydGlmaWNhdGlvbiBBdXRob3Jp
 | 
			
		||||
dHkgUjEgdGVzdDAeFw0xOTA4MTkxMTE2MDBaFw0yNDA4MDExMTE2MDBaMIGRMQswCQYDVQQGEwJD
 | 
			
		||||
TjEbMBkGA1UECgwSQW50IEZpbmFuY2lhbCB0ZXN0MSUwIwYDVQQLDBxDZXJ0aWZpY2F0aW9uIEF1
 | 
			
		||||
dGhvcml0eSB0ZXN0MT4wPAYDVQQDDDVBbnQgRmluYW5jaWFsIENlcnRpZmljYXRpb24gQXV0aG9y
 | 
			
		||||
aXR5IENsYXNzIDIgUjEgdGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMh4FKYO
 | 
			
		||||
ZyRQHD6eFbPKZeSAnrfjfU7xmS9Yoozuu+iuqZlb6Z0SPLUqqTZAFZejOcmr07ln/pwZxluqplxC
 | 
			
		||||
5+B48End4nclDMlT5HPrDr3W0frs6Xsa2ZNcyil/iKNB5MbGll8LRAxntsKvZZj6vUTMb705gYgm
 | 
			
		||||
VUMILwi/ZxKTQqBtkT/kQQ5y6nOZsj7XI5rYdz6qqOROrpvS/d7iypdHOMIM9Iz9DlL1mrCykbBi
 | 
			
		||||
t25y+gTeXmuisHUwqaRpwtCGK4BayCqxRGbNipe6W73EK9lBrrzNtTr9NaysesT/v+l25JHCL9tG
 | 
			
		||||
wpNr1oWFzk4IHVOg0ORiQ6SUgxZUTYcCAwEAAaMSMBAwDgYDVR0PAQH/BAQDAgTwMA0GCSqGSIb3
 | 
			
		||||
DQEBCwUAA4IBAQBWThEoIaQoBX2YeRY/I8gu6TYnFXtyuCljANnXnM38ft+ikhE5mMNgKmJYLHvT
 | 
			
		||||
yWWWgwHoSAWEuml7EGbE/2AK2h3k0MdfiWLzdmpPCRG/RJHk6UB1pMHPilI+c0MVu16OPpKbg5Vf
 | 
			
		||||
LTv7dsAB40AzKsvyYw88/Ezi1osTXo6QQwda7uefvudirtb8FcQM9R66cJxl3kt1FXbpYwheIm/p
 | 
			
		||||
j1mq64swCoIYu4NrsUYtn6CV542DTQMI5QdXkn+PzUUly8F6kDp+KpMNd0avfWNL5+O++z+F5Szy
 | 
			
		||||
1CPta1D7EQ/eYmMP+mOQ35oifWIoFCpN6qQVBS/Hob1J/UUyg7BW
 | 
			
		||||
-----END CERTIFICATE-----
 | 
			
		||||
							
								
								
									
										88
									
								
								api/res/certs/alipay/alipayRootCert.crt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										88
									
								
								api/res/certs/alipay/alipayRootCert.crt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,88 @@
 | 
			
		||||
-----BEGIN CERTIFICATE-----
 | 
			
		||||
MIIBszCCAVegAwIBAgIIaeL+wBcKxnswDAYIKoEcz1UBg3UFADAuMQswCQYDVQQG
 | 
			
		||||
EwJDTjEOMAwGA1UECgwFTlJDQUMxDzANBgNVBAMMBlJPT1RDQTAeFw0xMjA3MTQw
 | 
			
		||||
MzExNTlaFw00MjA3MDcwMzExNTlaMC4xCzAJBgNVBAYTAkNOMQ4wDAYDVQQKDAVO
 | 
			
		||||
UkNBQzEPMA0GA1UEAwwGUk9PVENBMFkwEwYHKoZIzj0CAQYIKoEcz1UBgi0DQgAE
 | 
			
		||||
MPCca6pmgcchsTf2UnBeL9rtp4nw+itk1Kzrmbnqo05lUwkwlWK+4OIrtFdAqnRT
 | 
			
		||||
V7Q9v1htkv42TsIutzd126NdMFswHwYDVR0jBBgwFoAUTDKxl9kzG8SmBcHG5Yti
 | 
			
		||||
W/CXdlgwDAYDVR0TBAUwAwEB/zALBgNVHQ8EBAMCAQYwHQYDVR0OBBYEFEwysZfZ
 | 
			
		||||
MxvEpgXBxuWLYlvwl3ZYMAwGCCqBHM9VAYN1BQADSAAwRQIgG1bSLeOXp3oB8H7b
 | 
			
		||||
53W+CKOPl2PknmWEq/lMhtn25HkCIQDaHDgWxWFtnCrBjH16/W3Ezn7/U/Vjo5xI
 | 
			
		||||
pDoiVhsLwg==
 | 
			
		||||
-----END CERTIFICATE-----
 | 
			
		||||
 | 
			
		||||
-----BEGIN CERTIFICATE-----
 | 
			
		||||
MIIF0zCCA7ugAwIBAgIIH8+hjWpIDREwDQYJKoZIhvcNAQELBQAwejELMAkGA1UE
 | 
			
		||||
BhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNVBAsMF0NlcnRpZmlj
 | 
			
		||||
YXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5jaWFsIENlcnRpZmlj
 | 
			
		||||
YXRpb24gQXV0aG9yaXR5IFIxMB4XDTE4MDMyMTEzNDg0MFoXDTM4MDIyODEzNDg0
 | 
			
		||||
MFowejELMAkGA1UEBhMCQ04xFjAUBgNVBAoMDUFudCBGaW5hbmNpYWwxIDAeBgNV
 | 
			
		||||
BAsMF0NlcnRpZmljYXRpb24gQXV0aG9yaXR5MTEwLwYDVQQDDChBbnQgRmluYW5j
 | 
			
		||||
aWFsIENlcnRpZmljYXRpb24gQXV0aG9yaXR5IFIxMIICIjANBgkqhkiG9w0BAQEF
 | 
			
		||||
AAOCAg8AMIICCgKCAgEAtytTRcBNuur5h8xuxnlKJetT65cHGemGi8oD+beHFPTk
 | 
			
		||||
rUTlFt9Xn7fAVGo6QSsPb9uGLpUFGEdGmbsQ2q9cV4P89qkH04VzIPwT7AywJdt2
 | 
			
		||||
xAvMs+MgHFJzOYfL1QkdOOVO7NwKxH8IvlQgFabWomWk2Ei9WfUyxFjVO1LVh0Bp
 | 
			
		||||
dRBeWLMkdudx0tl3+21t1apnReFNQ5nfX29xeSxIhesaMHDZFViO/DXDNW2BcTs6
 | 
			
		||||
vSWKyJ4YIIIzStumD8K1xMsoaZBMDxg4itjWFaKRgNuPiIn4kjDY3kC66Sl/6yTl
 | 
			
		||||
YUz8AybbEsICZzssdZh7jcNb1VRfk79lgAprm/Ktl+mgrU1gaMGP1OE25JCbqli1
 | 
			
		||||
Pbw/BpPynyP9+XulE+2mxFwTYhKAwpDIDKuYsFUXuo8t261pCovI1CXFzAQM2w7H
 | 
			
		||||
DtA2nOXSW6q0jGDJ5+WauH+K8ZSvA6x4sFo4u0KNCx0ROTBpLif6GTngqo3sj+98
 | 
			
		||||
SZiMNLFMQoQkjkdN5Q5g9N6CFZPVZ6QpO0JcIc7S1le/g9z5iBKnifrKxy0TQjtG
 | 
			
		||||
PsDwc8ubPnRm/F82RReCoyNyx63indpgFfhN7+KxUIQ9cOwwTvemmor0A+ZQamRe
 | 
			
		||||
9LMuiEfEaWUDK+6O0Gl8lO571uI5onYdN1VIgOmwFbe+D8TcuzVjIZ/zvHrAGUcC
 | 
			
		||||
AwEAAaNdMFswCwYDVR0PBAQDAgEGMAwGA1UdEwQFMAMBAf8wHQYDVR0OBBYEFF90
 | 
			
		||||
tATATwda6uWx2yKjh0GynOEBMB8GA1UdIwQYMBaAFF90tATATwda6uWx2yKjh0Gy
 | 
			
		||||
nOEBMA0GCSqGSIb3DQEBCwUAA4ICAQCVYaOtqOLIpsrEikE5lb+UARNSFJg6tpkf
 | 
			
		||||
tJ2U8QF/DejemEHx5IClQu6ajxjtu0Aie4/3UnIXop8nH/Q57l+Wyt9T7N2WPiNq
 | 
			
		||||
JSlYKYbJpPF8LXbuKYG3BTFTdOVFIeRe2NUyYh/xs6bXGr4WKTXb3qBmzR02FSy3
 | 
			
		||||
IODQw5Q6zpXj8prYqFHYsOvGCEc1CwJaSaYwRhTkFedJUxiyhyB5GQwoFfExCVHW
 | 
			
		||||
05ZFCAVYFldCJvUzfzrWubN6wX0DD2dwultgmldOn/W/n8at52mpPNvIdbZb2F41
 | 
			
		||||
T0YZeoWnCJrYXjq/32oc1cmifIHqySnyMnavi75DxPCdZsCOpSAT4j4lAQRGsfgI
 | 
			
		||||
kkLPGQieMfNNkMCKh7qjwdXAVtdqhf0RVtFILH3OyEodlk1HYXqX5iE5wlaKzDop
 | 
			
		||||
PKwf2Q3BErq1xChYGGVS+dEvyXc/2nIBlt7uLWKp4XFjqekKbaGaLJdjYP5b2s7N
 | 
			
		||||
1dM0MXQ/f8XoXKBkJNzEiM3hfsU6DOREgMc1DIsFKxfuMwX3EkVQM1If8ghb6x5Y
 | 
			
		||||
jXayv+NLbidOSzk4vl5QwngO/JYFMkoc6i9LNwEaEtR9PhnrdubxmrtM+RjfBm02
 | 
			
		||||
77q3dSWFESFQ4QxYWew4pHE0DpWbWy/iMIKQ6UZ5RLvB8GEcgt8ON7BBJeMc+Dyi
 | 
			
		||||
kT9qhqn+lw==
 | 
			
		||||
-----END CERTIFICATE-----
 | 
			
		||||
 | 
			
		||||
-----BEGIN CERTIFICATE-----
 | 
			
		||||
MIICiDCCAgygAwIBAgIIQX76UsB/30owDAYIKoZIzj0EAwMFADB6MQswCQYDVQQG
 | 
			
		||||
EwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UECwwXQ2VydGlmaWNh
 | 
			
		||||
dGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNpYWwgQ2VydGlmaWNh
 | 
			
		||||
dGlvbiBBdXRob3JpdHkgRTEwHhcNMTkwNDI4MTYyMDQ0WhcNNDkwNDIwMTYyMDQ0
 | 
			
		||||
WjB6MQswCQYDVQQGEwJDTjEWMBQGA1UECgwNQW50IEZpbmFuY2lhbDEgMB4GA1UE
 | 
			
		||||
CwwXQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkxMTAvBgNVBAMMKEFudCBGaW5hbmNp
 | 
			
		||||
YWwgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkgRTEwdjAQBgcqhkjOPQIBBgUrgQQA
 | 
			
		||||
IgNiAASCCRa94QI0vR5Up9Yr9HEupz6hSoyjySYqo7v837KnmjveUIUNiuC9pWAU
 | 
			
		||||
WP3jwLX3HkzeiNdeg22a0IZPoSUCpasufiLAnfXh6NInLiWBrjLJXDSGaY7vaokt
 | 
			
		||||
rpZvAdmjXTBbMAsGA1UdDwQEAwIBBjAMBgNVHRMEBTADAQH/MB0GA1UdDgQWBBRZ
 | 
			
		||||
4ZTgDpksHL2qcpkFkxD2zVd16TAfBgNVHSMEGDAWgBRZ4ZTgDpksHL2qcpkFkxD2
 | 
			
		||||
zVd16TAMBggqhkjOPQQDAwUAA2gAMGUCMQD4IoqT2hTUn0jt7oXLdMJ8q4vLp6sg
 | 
			
		||||
wHfPiOr9gxreb+e6Oidwd2LDnC4OUqCWiF8CMAzwKs4SnDJYcMLf2vpkbuVE4dTH
 | 
			
		||||
Rglz+HGcTLWsFs4KxLsq7MuU+vJTBUeDJeDjdA==
 | 
			
		||||
-----END CERTIFICATE-----
 | 
			
		||||
 | 
			
		||||
-----BEGIN CERTIFICATE-----
 | 
			
		||||
MIIDxTCCAq2gAwIBAgIUEMdk6dVgOEIS2cCP0Q43P90Ps5YwDQYJKoZIhvcNAQEF
 | 
			
		||||
BQAwajELMAkGA1UEBhMCQ04xEzARBgNVBAoMCmlUcnVzQ2hpbmExHDAaBgNVBAsM
 | 
			
		||||
E0NoaW5hIFRydXN0IE5ldHdvcmsxKDAmBgNVBAMMH2lUcnVzQ2hpbmEgQ2xhc3Mg
 | 
			
		||||
MiBSb290IENBIC0gRzMwHhcNMTMwNDE4MDkzNjU2WhcNMzMwNDE4MDkzNjU2WjBq
 | 
			
		||||
MQswCQYDVQQGEwJDTjETMBEGA1UECgwKaVRydXNDaGluYTEcMBoGA1UECwwTQ2hp
 | 
			
		||||
bmEgVHJ1c3QgTmV0d29yazEoMCYGA1UEAwwfaVRydXNDaGluYSBDbGFzcyAyIFJv
 | 
			
		||||
b3QgQ0EgLSBHMzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOPPShpV
 | 
			
		||||
nJbMqqCw6Bz1kehnoPst9pkr0V9idOwU2oyS47/HjJXk9Rd5a9xfwkPO88trUpz5
 | 
			
		||||
4GmmwspDXjVFu9L0eFaRuH3KMha1Ak01citbF7cQLJlS7XI+tpkTGHEY5pt3EsQg
 | 
			
		||||
wykfZl/A1jrnSkspMS997r2Gim54cwz+mTMgDRhZsKK/lbOeBPpWtcFizjXYCqhw
 | 
			
		||||
WktvQfZBYi6o4sHCshnOswi4yV1p+LuFcQ2ciYdWvULh1eZhLxHbGXyznYHi0dGN
 | 
			
		||||
z+I9H8aXxqAQfHVhbdHNzi77hCxFjOy+hHrGsyzjrd2swVQ2iUWP8BfEQqGLqM1g
 | 
			
		||||
KgWKYfcTGdbPB1MCAwEAAaNjMGEwHQYDVR0OBBYEFG/oAMxTVe7y0+408CTAK8hA
 | 
			
		||||
uTyRMB8GA1UdIwQYMBaAFG/oAMxTVe7y0+408CTAK8hAuTyRMA8GA1UdEwEB/wQF
 | 
			
		||||
MAMBAf8wDgYDVR0PAQH/BAQDAgEGMA0GCSqGSIb3DQEBBQUAA4IBAQBLnUTfW7hp
 | 
			
		||||
emMbuUGCk7RBswzOT83bDM6824EkUnf+X0iKS95SUNGeeSWK2o/3ALJo5hi7GZr3
 | 
			
		||||
U8eLaWAcYizfO99UXMRBPw5PRR+gXGEronGUugLpxsjuynoLQu8GQAeysSXKbN1I
 | 
			
		||||
UugDo9u8igJORYA+5ms0s5sCUySqbQ2R5z/GoceyI9LdxIVa1RjVX8pYOj8JFwtn
 | 
			
		||||
DJN3ftSFvNMYwRuILKuqUYSHc2GPYiHVflDh5nDymCMOQFcFG3WsEuB+EYQPFgIU
 | 
			
		||||
1DHmdZcz7Llx8UOZXX2JupWCYzK1XhJb+r4hK5ncf/w8qGtYlmyJpxk3hr1TfUJX
 | 
			
		||||
Yf4Zr0fJsGuv
 | 
			
		||||
-----END CERTIFICATE-----
 | 
			
		||||
							
								
								
									
										19
									
								
								api/res/certs/alipay/appPublicCert.crt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										19
									
								
								api/res/certs/alipay/appPublicCert.crt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,19 @@
 | 
			
		||||
-----BEGIN CERTIFICATE-----
 | 
			
		||||
MIIDmTCCAoGgAwIBAgIQICMRB2LW76yahgdg3IFNPDANBgkqhkiG9w0BAQsFADCBkTELMAkGA1UE
 | 
			
		||||
BhMCQ04xGzAZBgNVBAoMEkFudCBGaW5hbmNpYWwgdGVzdDElMCMGA1UECwwcQ2VydGlmaWNhdGlv
 | 
			
		||||
biBBdXRob3JpdHkgdGVzdDE+MDwGA1UEAww1QW50IEZpbmFuY2lhbCBDZXJ0aWZpY2F0aW9uIEF1
 | 
			
		||||
dGhvcml0eSBDbGFzcyAyIFIxIHRlc3QwHhcNMjMxMTA3MDU0NjE5WhcNMjQxMTExMDU0NjE5WjBr
 | 
			
		||||
MQswCQYDVQQGEwJDTjEfMB0GA1UECgwWbWJvbmZ5OTAxNUBzYW5kYm94LmNvbTEPMA0GA1UECwwG
 | 
			
		||||
QWxpcGF5MSowKAYDVQQDDCEyMDg4NzIxMDIwNzUwNTgxLTkwMjEwMDAxMzE2NTgwMjMwggEiMA0G
 | 
			
		||||
CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCxihQPf1Q+g9ArgM46shVqL5sbRha/df95D1PsWyEq
 | 
			
		||||
ANmWmG4zZ+ksYDVQrc4KzhSRoi56sm/7TDFYTmM6bW99e/nKW58WxyZB4ie5qA3F4n17psPyDqb8
 | 
			
		||||
IokcQmCphSFDaXQD6AoXoLNtTM0vAI2cWxAgebZ/vsrdj5Ntjt+Rp3NYMCk1i5xovHcfILzLEGbX
 | 
			
		||||
QXoT9fo5AhHotTWa6xHVLPUGY9qwLzQxHzBmvy5ZMfnOfJkm/mDisTSqAUB59F3dzU/1ARVkEZ1w
 | 
			
		||||
Mgb4XohWBw6iurQfbMnH2mIomAAwwZVFv+sXDbL9yMbSMo/SjVsTQprn0Q0EnwLo7nmmOM6HAgMB
 | 
			
		||||
AAGjEjAQMA4GA1UdDwEB/wQEAwIE8DANBgkqhkiG9w0BAQsFAAOCAQEAn3Y4/C1h9R6ONsBqX3/q
 | 
			
		||||
XfHX7yX1FM0Y1x48X3/Yxk6HivAkTukhhhVYVKJsbrbzRqHDp9vhAP/FR6o6pAevaYMmLov0VMXU
 | 
			
		||||
7oAuetgkaYEYkDuNen5/Hpdhqi2vTtdT+q9w8zHJd6MDQ0aoHgIxpLKw5vof2R1N4fwSgNXMiXE5
 | 
			
		||||
kmllKQMem/+on2p+Sj80/2asxryHIGlH87qPzkffv+kIOkZthbTApTFLLjdVri2QHGe8/cc4xy01
 | 
			
		||||
/9iR3IUzNahotT41lJ4bMevBY7XMAS3n5ekyABN/9ZRJqhWdXgmFCRN/u56qd6lDgu7R2M2QUoyc
 | 
			
		||||
LuW5DfgRItKlmUB7sw==
 | 
			
		||||
-----END CERTIFICATE-----
 | 
			
		||||
							
								
								
									
										1
									
								
								api/res/certs/alipay/privateKey.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										1
									
								
								api/res/certs/alipay/privateKey.txt
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1 @@
 | 
			
		||||
MIIEpQIBAAKCAQEAsYoUD39UPoPQK4DOOrIVai+bG0YWv3X/eQ9T7FshKgDZlphuM2fpLGA1UK3OCs4UkaIuerJv+0wxWE5jOm1vfXv5ylufFscmQeInuagNxeJ9e6bD8g6m/CKJHEJgqYUhQ2l0A+gKF6CzbUzNLwCNnFsQIHm2f77K3Y+TbY7fkadzWDApNYucaLx3HyC8yxBm10F6E/X6OQIR6LU1musR1Sz1BmPasC80MR8wZr8uWTH5znyZJv5g4rE0qgFAefRd3c1P9QEVZBGdcDIG+F6IVgcOorq0H2zJx9piKJgAMMGVRb/rFw2y/cjG0jKP0o1bE0Ka59ENBJ8C6O55pjjOhwIDAQABAoIBAFetNfz1R7hbxjlFshMAkVzQR8wvT9qbvl+dtzdZRcaFhu89NecDIP7+QDYor0FcxoGpU0TazDyRQyk2BQD8vHt+9zv9BVLtZLJSqoWgPbUFBi1DjS8EF2ka8RVYnn35NhUhhd7L//ftL88Bh673mfembQ9srDjoEy1Z01feoABAnCMkNFl986DmEwnarvEufXSDIgeN4ioMxha4NvfIPuI0zpVdV1O9sv+SGC+VEWZBtN3GNsaf4zS/f8FVGvTiU/Abz0gSw/iwSPHclDWQDTN3yFHf/tfqlzh0mH0WfhnuOBFWXzK+R7fbnM+asI9ttvzRcfpzgRGXdPcNcOv/6cECgYEA3DVqpi1k8MYfJixju6SG5gfyhM4VFksFmCMaNPgtatDMBKLMTgV/Ej6LXREojcy29uZl83F09pVlpd41eG39ULIPktixA/BqErQ2UaWh6kOxifycpu22Jh0r09hax6UgVrcBrrnCJEjcFsuJlrZvXQSzc3PBxjWy5gjabS5h9iECgYEAzmVAIh2frF01Y95zsLueAhhZwCtPanm6kf7ivR4r1plIX3b2sNRhWGmEHFgaCE6Braa0ogQ73Hd26kw4ZW+D6QMGC/zjCBEzDLLf++SjdVUHiY5AR4WHqXzq1jdAlsVyo9R661oAOp3lhiJVGLNXkHyEfEVPHsaxJh4osYSbX6cCgYEAx32Qx0i6eDFTyLZQB46uMrgiaVN04QRH5iJuvGvUYT8UhGKjaU8rZfDJOh+wOH2rhxMEaz1uc3C2bERY9mfWI4Ob/jFWc7YZsiYWS3Mcsuhubw4tMECLUg39RWZsHw8ls8kIuixIh6yFzhTH6YQOcRswIrhMZG8DScfdcSmiz2ECgYEAkWP1t5KSpkLKl11etcKUXfl1T8+yk9jIOowIgRw92WAFAWq2AH67TCKYM7dEL1HOO9tRJ0hAOt/U3ttuZtYVYBEHM26jJ02mXm2rJrA7DS4mrxmL4lYH6LbcXqZxU0Qnq4zEQgIWYzRTORf6Rfof1uJAGaJhR9bDd4yLMfGt2cUCgYEAo216Y61xOHUTA4AF1eekk+r+uOcQgQDvLXfs9FkDdJLk0mPG48/+eIYpPFnANJ/riF/DWOp8WGEe2IzA9yUFexzDbNQK8ha9kGcxaSAyiCwzjZ/t9/+hScDSV8kNqWSRSisu/YOFleEHbokT6mbLZ+gdqES8mUUanaEBzRQYGxo=
 | 
			
		||||
							
								
								
									
										
											BIN
										
									
								
								api/res/img/wechat-pay.jpg
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										
											BIN
										
									
								
								api/res/img/wechat-pay.jpg
									
									
									
									
									
										Normal file
									
								
							
										
											Binary file not shown.
										
									
								
							| 
		 After Width: | Height: | Size: 5.7 KiB  | 
@@ -1,67 +0,0 @@
 | 
			
		||||
{
 | 
			
		||||
  "data": [
 | 
			
		||||
    "task(m1wpaa4v60zedj8)",
 | 
			
		||||
    "a cute cat",
 | 
			
		||||
    "",
 | 
			
		||||
    [],
 | 
			
		||||
    20,
 | 
			
		||||
    "DPM++ 2M Karras",
 | 
			
		||||
    1,
 | 
			
		||||
    1,
 | 
			
		||||
    7,
 | 
			
		||||
    512,
 | 
			
		||||
    384,
 | 
			
		||||
    true,
 | 
			
		||||
    0.7,
 | 
			
		||||
    2,
 | 
			
		||||
    "ESRGAN_4x",
 | 
			
		||||
    10,
 | 
			
		||||
    0,
 | 
			
		||||
    0,
 | 
			
		||||
    "Use same checkpoint",
 | 
			
		||||
    "Use same sampler",
 | 
			
		||||
    "",
 | 
			
		||||
    "",
 | 
			
		||||
    [],
 | 
			
		||||
    "None",
 | 
			
		||||
    false,
 | 
			
		||||
    "",
 | 
			
		||||
    0.8,
 | 
			
		||||
    -1,
 | 
			
		||||
    false,
 | 
			
		||||
    -1,
 | 
			
		||||
    0,
 | 
			
		||||
    0,
 | 
			
		||||
    0,
 | 
			
		||||
    false,
 | 
			
		||||
    false,
 | 
			
		||||
    "positive",
 | 
			
		||||
    "comma",
 | 
			
		||||
    0,
 | 
			
		||||
    false,
 | 
			
		||||
    false,
 | 
			
		||||
    "",
 | 
			
		||||
    "Seed",
 | 
			
		||||
    "",
 | 
			
		||||
    [],
 | 
			
		||||
    "Nothing",
 | 
			
		||||
    "",
 | 
			
		||||
    [],
 | 
			
		||||
    "Nothing",
 | 
			
		||||
    "",
 | 
			
		||||
    [],
 | 
			
		||||
    true,
 | 
			
		||||
    false,
 | 
			
		||||
    false,
 | 
			
		||||
    false,
 | 
			
		||||
    0,
 | 
			
		||||
    false,
 | 
			
		||||
    [],
 | 
			
		||||
    "",
 | 
			
		||||
    "",
 | 
			
		||||
    ""
 | 
			
		||||
  ],
 | 
			
		||||
  "event_data": null,
 | 
			
		||||
  "fn_index": 96,
 | 
			
		||||
  "session_hash": "kmb0ojjfhdj"
 | 
			
		||||
}
 | 
			
		||||
@@ -1,19 +1,26 @@
 | 
			
		||||
package service
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type CaptchaService struct {
 | 
			
		||||
	config types.ChatPlusApiConfig
 | 
			
		||||
	config types.ApiConfig
 | 
			
		||||
	client *req.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewCaptchaService(config types.ChatPlusApiConfig) *CaptchaService {
 | 
			
		||||
func NewCaptchaService(config types.ApiConfig) *CaptchaService {
 | 
			
		||||
	return &CaptchaService{
 | 
			
		||||
		config: config,
 | 
			
		||||
		client: req.C().SetTimeout(10 * time.Second),
 | 
			
		||||
@@ -60,3 +67,44 @@ func (s *CaptchaService) Check(data interface{}) bool {
 | 
			
		||||
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *CaptchaService) SlideGet() (interface{}, error) {
 | 
			
		||||
	if s.config.Token == "" {
 | 
			
		||||
		return nil, errors.New("无效的 API Token")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/captcha/slide/get", s.config.ApiURL)
 | 
			
		||||
	var res types.BizVo
 | 
			
		||||
	r, err := s.client.R().
 | 
			
		||||
		SetHeader("AppId", s.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return nil, fmt.Errorf("请求 API 失败:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return nil, fmt.Errorf("请求 API 失败:%s", res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res.Data, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *CaptchaService) SlideCheck(data interface{}) bool {
 | 
			
		||||
	url := fmt.Sprintf("%s/api/captcha/slide/check", s.config.ApiURL)
 | 
			
		||||
	var res types.BizVo
 | 
			
		||||
	r, err := s.client.R().
 | 
			
		||||
		SetHeader("AppId", s.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", s.config.Token)).
 | 
			
		||||
		SetBodyJsonMarshal(data).
 | 
			
		||||
		SetSuccessResult(&res).Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return false
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return true
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										307
									
								
								api/service/dalle/service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										307
									
								
								api/service/dalle/service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,307 @@
 | 
			
		||||
package dalle
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
// DALL-E 绘画服务
 | 
			
		||||
 | 
			
		||||
type Service struct {
 | 
			
		||||
	httpClient    *req.Client
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	notifyQueue   *store.RedisQueue
 | 
			
		||||
	Clients       *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(db *gorm.DB, manager *oss.UploaderManager, redisCli *redis.Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		httpClient:    req.C().SetTimeout(time.Minute * 3),
 | 
			
		||||
		db:            db,
 | 
			
		||||
		taskQueue:     store.NewRedisQueue("DallE_Task_Queue", redisCli),
 | 
			
		||||
		notifyQueue:   store.NewRedisQueue("DallE_Notify_Queue", redisCli),
 | 
			
		||||
		Clients:       types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PushTask push a new mj task in to task queue
 | 
			
		||||
func (s *Service) PushTask(task types.DallTask) {
 | 
			
		||||
	logger.Debugf("add a new MidJourney task to the task list: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			var task types.DallTask
 | 
			
		||||
			err := s.taskQueue.LPop(&task)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			_, err = s.Image(task, false)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Errorf("error with image task: %v", err)
 | 
			
		||||
				s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
					"progress": -1,
 | 
			
		||||
					"err_msg":  err.Error(),
 | 
			
		||||
				})
 | 
			
		||||
				s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Failed})
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type imgReq struct {
 | 
			
		||||
	Model   string `json:"model"`
 | 
			
		||||
	Prompt  string `json:"prompt"`
 | 
			
		||||
	N       int    `json:"n"`
 | 
			
		||||
	Size    string `json:"size"`
 | 
			
		||||
	Quality string `json:"quality"`
 | 
			
		||||
	Style   string `json:"style"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type imgRes struct {
 | 
			
		||||
	Created int64 `json:"created"`
 | 
			
		||||
	Data    []struct {
 | 
			
		||||
		RevisedPrompt string `json:"revised_prompt"`
 | 
			
		||||
		Url           string `json:"url"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type ErrRes struct {
 | 
			
		||||
	Error struct {
 | 
			
		||||
		Code    interface{} `json:"code"`
 | 
			
		||||
		Message string      `json:"message"`
 | 
			
		||||
		Param   interface{} `json:"param"`
 | 
			
		||||
		Type    string      `json:"type"`
 | 
			
		||||
	} `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Image(task types.DallTask, sync bool) (string, error) {
 | 
			
		||||
	logger.Debugf("绘画参数:%+v", task)
 | 
			
		||||
	prompt := task.Prompt
 | 
			
		||||
	// translate prompt
 | 
			
		||||
	if utils.HasChinese(task.Prompt) {
 | 
			
		||||
		content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", fmt.Errorf("error with translate prompt: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		prompt = content
 | 
			
		||||
		logger.Debugf("重写后提示词:%s", prompt)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var user model.User
 | 
			
		||||
	s.db.Where("id", task.UserId).First(&user)
 | 
			
		||||
	if user.Power < task.Power {
 | 
			
		||||
		return "", errors.New("insufficient of power")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// get image generation API KEY
 | 
			
		||||
	var apiKey model.ApiKey
 | 
			
		||||
	tx := s.db.Where("platform", types.OpenAI).
 | 
			
		||||
		Where("type", "img").
 | 
			
		||||
		Where("enabled", true).
 | 
			
		||||
		Order("last_used_at ASC").First(&apiKey)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return "", fmt.Errorf("no available IMG api key: %v", tx.Error)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res imgRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	if len(apiKey.ProxyURL) > 5 {
 | 
			
		||||
		s.httpClient.SetProxyURL(apiKey.ProxyURL).R()
 | 
			
		||||
	}
 | 
			
		||||
	logger.Debugf("Sending %s request, ApiURL:%s, API KEY:%s, PROXY: %s", apiKey.Platform, apiKey.ApiURL, apiKey.Value, apiKey.ProxyURL)
 | 
			
		||||
	r, err := s.httpClient.R().SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+apiKey.Value).
 | 
			
		||||
		SetBody(imgReq{
 | 
			
		||||
			Model:   "dall-e-3",
 | 
			
		||||
			Prompt:  prompt,
 | 
			
		||||
			N:       1,
 | 
			
		||||
			Size:    "1024x1024",
 | 
			
		||||
			Style:   task.Style,
 | 
			
		||||
			Quality: task.Quality,
 | 
			
		||||
		}).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		SetSuccessResult(&res).Post(apiKey.ApiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with send request: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("error with send request: %v", errRes.Error)
 | 
			
		||||
	}
 | 
			
		||||
	// update the api key last use time
 | 
			
		||||
	s.db.Model(&apiKey).UpdateColumn("last_used_at", time.Now().Unix())
 | 
			
		||||
	// update task progress
 | 
			
		||||
	s.db.Model(&model.DallJob{Id: task.JobId}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
		"progress": 100,
 | 
			
		||||
		"org_url":  res.Data[0].Url,
 | 
			
		||||
		"prompt":   prompt,
 | 
			
		||||
	})
 | 
			
		||||
 | 
			
		||||
	s.notifyQueue.RPush(sd.NotifyMessage{UserId: int(task.UserId), JobId: int(task.JobId), Message: sd.Finished})
 | 
			
		||||
	var content string
 | 
			
		||||
	if sync {
 | 
			
		||||
		imgURL, err := s.downloadImage(task.JobId, int(task.UserId), res.Data[0].Url)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			return "", fmt.Errorf("error with download image: %v", err)
 | 
			
		||||
		}
 | 
			
		||||
		content = fmt.Sprintf("```\n%s\n```\n下面是我为你创作的图片:\n\n\n", prompt, imgURL)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户算力
 | 
			
		||||
	tx = s.db.Model(&model.User{}).Where("id", user.Id).UpdateColumn("power", gorm.Expr("power - ?", task.Power))
 | 
			
		||||
	// 记录算力变化日志
 | 
			
		||||
	if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
		var u model.User
 | 
			
		||||
		s.db.Where("id", user.Id).First(&u)
 | 
			
		||||
		s.db.Create(&model.PowerLog{
 | 
			
		||||
			UserId:    user.Id,
 | 
			
		||||
			Username:  user.Username,
 | 
			
		||||
			Type:      types.PowerConsume,
 | 
			
		||||
			Amount:    task.Power,
 | 
			
		||||
			Balance:   u.Power,
 | 
			
		||||
			Mark:      types.PowerSub,
 | 
			
		||||
			Model:     "dall-e-3",
 | 
			
		||||
			Remark:    fmt.Sprintf("绘画提示词:%s", utils.CutWords(task.Prompt, 10)),
 | 
			
		||||
			CreatedAt: time.Now(),
 | 
			
		||||
		})
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return content, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running DALL-E task notify checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var message sd.NotifyMessage
 | 
			
		||||
			err := s.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			client := s.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if client == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = client.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) DownloadImages() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.DallJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := s.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// download images
 | 
			
		||||
			for _, v := range items {
 | 
			
		||||
				if v.OrgURL == "" {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Infof("try to download image: %s", v.OrgURL)
 | 
			
		||||
				imgURL, err := s.downloadImage(v.Id, int(v.UserId), v.OrgURL)
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Error("error with download image: %s, error: %v", imgURL, err)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 5)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) downloadImage(jobId uint, userId int, orgURL string) (string, error) {
 | 
			
		||||
	// sava image
 | 
			
		||||
	imgURL, err := s.uploadManager.GetUploadHandler().PutImg(orgURL, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// update img_url
 | 
			
		||||
	res := s.db.Model(&model.DallJob{Id: jobId}).UpdateColumn("img_url", imgURL)
 | 
			
		||||
	if res.Error != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	s.notifyQueue.RPush(sd.NotifyMessage{UserId: userId, JobId: int(jobId), Message: sd.Finished})
 | 
			
		||||
	return imgURL, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// CheckTaskStatus 检查任务状态,自动删除过期或者失败的任务
 | 
			
		||||
func (s *Service) CheckTaskStatus() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		logger.Info("Running Stable-Diffusion task status checking ...")
 | 
			
		||||
		for {
 | 
			
		||||
			var jobs []model.DallJob
 | 
			
		||||
			res := s.db.Where("progress < ?", 100).Find(&jobs)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				time.Sleep(5 * time.Second)
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range jobs {
 | 
			
		||||
				// 5 分钟还没完成的任务直接删除
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*5 || job.Progress == -1 {
 | 
			
		||||
					s.db.Delete(&job)
 | 
			
		||||
					var user model.User
 | 
			
		||||
					s.db.Where("id = ?", job.UserId).First(&user)
 | 
			
		||||
					// 退回绘图次数
 | 
			
		||||
					res = s.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
 | 
			
		||||
					if res.Error == nil && res.RowsAffected > 0 {
 | 
			
		||||
						s.db.Create(&model.PowerLog{
 | 
			
		||||
							UserId:    user.Id,
 | 
			
		||||
							Username:  user.Username,
 | 
			
		||||
							Type:      types.PowerConsume,
 | 
			
		||||
							Amount:    job.Power,
 | 
			
		||||
							Balance:   user.Power + job.Power,
 | 
			
		||||
							Mark:      types.PowerAdd,
 | 
			
		||||
							Model:     "dall-e-3",
 | 
			
		||||
							Remark:    fmt.Sprintf("任务失败,退回算力。任务ID:%d", job.Id),
 | 
			
		||||
							CreatedAt: time.Now(),
 | 
			
		||||
						})
 | 
			
		||||
					}
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
@@ -1,42 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// AI 绘画函数
 | 
			
		||||
 | 
			
		||||
type FuncMidJourney struct {
 | 
			
		||||
	name    string
 | 
			
		||||
	service *mj.Service
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewMidJourneyFunc(mjService *mj.Service) FuncMidJourney {
 | 
			
		||||
	return FuncMidJourney{
 | 
			
		||||
		name:    "MidJourney AI 绘画",
 | 
			
		||||
		service: mjService}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncMidJourney) Invoke(params map[string]interface{}) (string, error) {
 | 
			
		||||
	logger.Infof("MJ 绘画参数:%+v", params)
 | 
			
		||||
	prompt := utils.InterfaceToString(params["prompt"])
 | 
			
		||||
	f.service.PushTask(types.MjTask{
 | 
			
		||||
		SessionId: utils.InterfaceToString(params["session_id"]),
 | 
			
		||||
		Src:       types.TaskSrcChat,
 | 
			
		||||
		Type:      types.TaskImage,
 | 
			
		||||
		Prompt:    prompt,
 | 
			
		||||
		UserId:    utils.IntValue(utils.InterfaceToString(params["user_id"]), 0),
 | 
			
		||||
		RoleId:    utils.IntValue(utils.InterfaceToString(params["role_id"]), 0),
 | 
			
		||||
		Icon:      utils.InterfaceToString(params["icon"]),
 | 
			
		||||
		ChatId:    utils.InterfaceToString(params["chat_id"]),
 | 
			
		||||
	})
 | 
			
		||||
	return prompt, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncMidJourney) Name() string {
 | 
			
		||||
	return f.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Function = &FuncMidJourney{}
 | 
			
		||||
@@ -1,39 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/service/mj"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Function interface {
 | 
			
		||||
	Invoke(map[string]interface{}) (string, error)
 | 
			
		||||
	Name() string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type resVo struct {
 | 
			
		||||
	Code    types.BizCode `json:"code"`
 | 
			
		||||
	Message string        `json:"message"`
 | 
			
		||||
	Data    struct {
 | 
			
		||||
		Title     string     `json:"title"`
 | 
			
		||||
		UpdatedAt string     `json:"updated_at"`
 | 
			
		||||
		Items     []dataItem `json:"items"`
 | 
			
		||||
	} `json:"data"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type dataItem struct {
 | 
			
		||||
	Title  string `json:"title"`
 | 
			
		||||
	Url    string `json:"url"`
 | 
			
		||||
	Remark string `json:"remark"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewFunctions(config *types.AppConfig, mjService *mj.Service) map[string]Function {
 | 
			
		||||
	return map[string]Function{
 | 
			
		||||
		types.FuncZaoBao:     NewZaoBao(config.ApiConfig),
 | 
			
		||||
		types.FuncWeibo:      NewWeiboHot(config.ApiConfig),
 | 
			
		||||
		types.FuncHeadLine:   NewHeadLines(config.ApiConfig),
 | 
			
		||||
		types.FuncMidJourney: NewMidJourneyFunc(mjService),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
@@ -1,58 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 今日头条函数实现
 | 
			
		||||
 | 
			
		||||
type FuncHeadlines struct {
 | 
			
		||||
	name   string
 | 
			
		||||
	config types.ChatPlusApiConfig
 | 
			
		||||
	client *req.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewHeadLines(config types.ChatPlusApiConfig) FuncHeadlines {
 | 
			
		||||
	return FuncHeadlines{
 | 
			
		||||
		name:   "今日头条",
 | 
			
		||||
		config: config,
 | 
			
		||||
		client: req.C().SetTimeout(10 * time.Second)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncHeadlines) Invoke(map[string]interface{}) (string, error) {
 | 
			
		||||
	if f.config.Token == "" {
 | 
			
		||||
		return "", errors.New("无效的 API Token")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/headline/fetch", f.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := f.client.R().
 | 
			
		||||
		SetHeader("AppId", f.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("%v%v", err, r.Err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return "", errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
 | 
			
		||||
	for i, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [%s]", i+1, v.Title, v.Url, v.Remark))
 | 
			
		||||
	}
 | 
			
		||||
	return strings.Join(builder, "\n\n"), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncHeadlines) Name() string {
 | 
			
		||||
	return f.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Function = &FuncHeadlines{}
 | 
			
		||||
@@ -1,58 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 微博热搜函数实现
 | 
			
		||||
 | 
			
		||||
type FuncWeiboHot struct {
 | 
			
		||||
	name   string
 | 
			
		||||
	config types.ChatPlusApiConfig
 | 
			
		||||
	client *req.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewWeiboHot(config types.ChatPlusApiConfig) FuncWeiboHot {
 | 
			
		||||
	return FuncWeiboHot{
 | 
			
		||||
		name:   "微博热搜",
 | 
			
		||||
		config: config,
 | 
			
		||||
		client: req.C().SetTimeout(10 * time.Second)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncWeiboHot) Invoke(map[string]interface{}) (string, error) {
 | 
			
		||||
	if f.config.Token == "" {
 | 
			
		||||
		return "", errors.New("无效的 API Token")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/weibo/fetch", f.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := f.client.R().
 | 
			
		||||
		SetHeader("AppId", f.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("%v%v", err, r.Err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return "", errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s**,最新更新:%s", res.Data.Title, res.Data.UpdatedAt))
 | 
			
		||||
	for i, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, fmt.Sprintf("%d、 [%s](%s) [热度:%s]", i+1, v.Title, v.Url, v.Remark))
 | 
			
		||||
	}
 | 
			
		||||
	return strings.Join(builder, "\n\n"), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncWeiboHot) Name() string {
 | 
			
		||||
	return f.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Function = &FuncWeiboHot{}
 | 
			
		||||
@@ -1,59 +0,0 @@
 | 
			
		||||
package fun
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// 每日早报函数实现
 | 
			
		||||
 | 
			
		||||
type FuncZaoBao struct {
 | 
			
		||||
	name   string
 | 
			
		||||
	config types.ChatPlusApiConfig
 | 
			
		||||
	client *req.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewZaoBao(config types.ChatPlusApiConfig) FuncZaoBao {
 | 
			
		||||
	return FuncZaoBao{
 | 
			
		||||
		name:   "每日早报",
 | 
			
		||||
		config: config,
 | 
			
		||||
		client: req.C().SetTimeout(10 * time.Second)}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncZaoBao) Invoke(map[string]interface{}) (string, error) {
 | 
			
		||||
	if f.config.Token == "" {
 | 
			
		||||
		return "", errors.New("无效的 API Token")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := fmt.Sprintf("%s/api/zaobao/fetch", f.config.ApiURL)
 | 
			
		||||
	var res resVo
 | 
			
		||||
	r, err := f.client.R().
 | 
			
		||||
		SetHeader("AppId", f.config.AppId).
 | 
			
		||||
		SetHeader("Authorization", fmt.Sprintf("Bearer %s", f.config.Token)).
 | 
			
		||||
		SetSuccessResult(&res).Get(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return "", fmt.Errorf("%v%v", err, r.Err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.Code != types.Success {
 | 
			
		||||
		return "", errors.New(res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	builder := make([]string, 0)
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("**%s 早报:**", res.Data.UpdatedAt))
 | 
			
		||||
	for _, v := range res.Data.Items {
 | 
			
		||||
		builder = append(builder, v.Title)
 | 
			
		||||
	}
 | 
			
		||||
	builder = append(builder, fmt.Sprintf("%s", res.Data.Title))
 | 
			
		||||
	return strings.Join(builder, "\n\n"), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (f FuncZaoBao) Name() string {
 | 
			
		||||
	return f.name
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Function = &FuncZaoBao{}
 | 
			
		||||
@@ -1,213 +0,0 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"github.com/bwmarrin/discordgo"
 | 
			
		||||
	"github.com/gorilla/websocket"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"regexp"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MidJourney 机器人
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
type Bot struct {
 | 
			
		||||
	config  *types.MidJourneyConfig
 | 
			
		||||
	bot     *discordgo.Session
 | 
			
		||||
	service *Service
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewBot(config *types.AppConfig, service *Service) (*Bot, error) {
 | 
			
		||||
	discord, err := discordgo.New("Bot " + config.MjConfig.BotToken)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.ProxyURL != "" {
 | 
			
		||||
		proxy, _ := url.Parse(config.ProxyURL)
 | 
			
		||||
		discord.Client = &http.Client{
 | 
			
		||||
			Transport: &http.Transport{
 | 
			
		||||
				Proxy: http.ProxyURL(proxy),
 | 
			
		||||
			},
 | 
			
		||||
		}
 | 
			
		||||
		discord.Dialer = &websocket.Dialer{
 | 
			
		||||
			Proxy: http.ProxyURL(proxy),
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &Bot{
 | 
			
		||||
		config:  &config.MjConfig,
 | 
			
		||||
		bot:     discord,
 | 
			
		||||
		service: service,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) Run() error {
 | 
			
		||||
	b.bot.Identify.Intents = discordgo.IntentsAllWithoutPrivileged | discordgo.IntentsGuildMessages | discordgo.IntentMessageContent
 | 
			
		||||
	b.bot.AddHandler(b.messageCreate)
 | 
			
		||||
	b.bot.AddHandler(b.messageUpdate)
 | 
			
		||||
 | 
			
		||||
	logger.Info("Starting MidJourney Bot...")
 | 
			
		||||
	err := b.bot.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("Error opening Discord connection:", err)
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("Starting MidJourney Bot successfully!")
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type TaskStatus string
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	Start    = TaskStatus("Started")
 | 
			
		||||
	Running  = TaskStatus("Running")
 | 
			
		||||
	Stopped  = TaskStatus("Stopped")
 | 
			
		||||
	Finished = TaskStatus("Finished")
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type Image struct {
 | 
			
		||||
	URL      string `json:"url"`
 | 
			
		||||
	ProxyURL string `json:"proxy_url"`
 | 
			
		||||
	Filename string `json:"filename"`
 | 
			
		||||
	Width    int    `json:"width"`
 | 
			
		||||
	Height   int    `json:"height"`
 | 
			
		||||
	Size     int    `json:"size"`
 | 
			
		||||
	Hash     string `json:"hash"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) messageCreate(s *discordgo.Session, m *discordgo.MessageCreate) {
 | 
			
		||||
	// ignore messages for other channels
 | 
			
		||||
	if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// ignore messages for self
 | 
			
		||||
	if m.Author.ID == s.State.User.ID {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("CREATE: %s", utils.JsonEncode(m))
 | 
			
		||||
	var referenceId = ""
 | 
			
		||||
	if m.ReferencedMessage != nil {
 | 
			
		||||
		referenceId = m.ReferencedMessage.ID
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(m.Content, "(Waiting to start)") && !strings.Contains(m.Content, "Rerolling **") {
 | 
			
		||||
		// parse content
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			MessageId:   m.ID,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Prompt:      extractPrompt(m.Content),
 | 
			
		||||
			Content:     m.Content,
 | 
			
		||||
			Progress:    0,
 | 
			
		||||
			Status:      Start}
 | 
			
		||||
		b.service.Notify(req)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) messageUpdate(s *discordgo.Session, m *discordgo.MessageUpdate) {
 | 
			
		||||
	// ignore messages for other channels
 | 
			
		||||
	if m.GuildID != b.config.GuildId || m.ChannelID != b.config.ChanelId {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
	// ignore messages for self
 | 
			
		||||
	if m.Author.ID == s.State.User.ID {
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	logger.Debugf("UPDATE: %s", utils.JsonEncode(m))
 | 
			
		||||
 | 
			
		||||
	var referenceId = ""
 | 
			
		||||
	if m.ReferencedMessage != nil {
 | 
			
		||||
		referenceId = m.ReferencedMessage.ID
 | 
			
		||||
	}
 | 
			
		||||
	if strings.Contains(m.Content, "(Stopped)") {
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			MessageId:   m.ID,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Prompt:      extractPrompt(m.Content),
 | 
			
		||||
			Content:     m.Content,
 | 
			
		||||
			Progress:    extractProgress(m.Content),
 | 
			
		||||
			Status:      Stopped}
 | 
			
		||||
		b.service.Notify(req)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	b.addAttachment(m.ID, referenceId, m.Content, m.Attachments)
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (b *Bot) addAttachment(messageId string, referenceId string, content string, attachments []*discordgo.MessageAttachment) {
 | 
			
		||||
	progress := extractProgress(content)
 | 
			
		||||
	var status TaskStatus
 | 
			
		||||
	if progress == 100 {
 | 
			
		||||
		status = Finished
 | 
			
		||||
	} else {
 | 
			
		||||
		status = Running
 | 
			
		||||
	}
 | 
			
		||||
	for _, attachment := range attachments {
 | 
			
		||||
		if attachment.Width == 0 || attachment.Height == 0 {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		image := Image{
 | 
			
		||||
			URL:      attachment.URL,
 | 
			
		||||
			Height:   attachment.Height,
 | 
			
		||||
			ProxyURL: attachment.ProxyURL,
 | 
			
		||||
			Width:    attachment.Width,
 | 
			
		||||
			Size:     attachment.Size,
 | 
			
		||||
			Filename: attachment.Filename,
 | 
			
		||||
			Hash:     extractHashFromFilename(attachment.Filename),
 | 
			
		||||
		}
 | 
			
		||||
		req := CBReq{
 | 
			
		||||
			MessageId:   messageId,
 | 
			
		||||
			ReferenceId: referenceId,
 | 
			
		||||
			Image:       image,
 | 
			
		||||
			Prompt:      extractPrompt(content),
 | 
			
		||||
			Content:     content,
 | 
			
		||||
			Progress:    progress,
 | 
			
		||||
			Status:      status,
 | 
			
		||||
		}
 | 
			
		||||
		b.service.Notify(req)
 | 
			
		||||
		break // only get one image
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// extract prompt from string
 | 
			
		||||
func extractPrompt(input string) string {
 | 
			
		||||
	pattern := `\*\*(.*?)\*\*`
 | 
			
		||||
	re := regexp.MustCompile(pattern)
 | 
			
		||||
	matches := re.FindStringSubmatch(input)
 | 
			
		||||
	if len(matches) > 1 {
 | 
			
		||||
		return strings.TrimSpace(matches[1])
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func extractProgress(input string) int {
 | 
			
		||||
	pattern := `\((\d+)\%\)`
 | 
			
		||||
	re := regexp.MustCompile(pattern)
 | 
			
		||||
	matches := re.FindStringSubmatch(input)
 | 
			
		||||
	if len(matches) > 1 {
 | 
			
		||||
		return utils.IntValue(matches[1], 0)
 | 
			
		||||
	}
 | 
			
		||||
	return 100
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func extractHashFromFilename(filename string) string {
 | 
			
		||||
	if !strings.HasSuffix(filename, ".png") {
 | 
			
		||||
		return ""
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	index := strings.LastIndex(filename, "_")
 | 
			
		||||
	if index != -1 {
 | 
			
		||||
		return filename[index+1 : len(filename)-4]
 | 
			
		||||
	}
 | 
			
		||||
	return ""
 | 
			
		||||
}
 | 
			
		||||
@@ -1,144 +1,68 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
// MidJourney client
 | 
			
		||||
import "geekai/core/types"
 | 
			
		||||
 | 
			
		||||
type Client struct {
 | 
			
		||||
	client *req.Client
 | 
			
		||||
	config *types.MidJourneyConfig
 | 
			
		||||
type Client interface {
 | 
			
		||||
	Imagine(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	Blend(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	SwapFace(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	Upscale(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	Variation(task types.MjTask) (ImageRes, error)
 | 
			
		||||
	QueryTask(taskId string) (QueryRes, error)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewClient(config *types.AppConfig) *Client {
 | 
			
		||||
	client := req.C().SetTimeout(10 * time.Second)
 | 
			
		||||
	// set proxy URL
 | 
			
		||||
	if config.ProxyURL != "" {
 | 
			
		||||
		client.SetProxyURL(config.ProxyURL)
 | 
			
		||||
	}
 | 
			
		||||
	return &Client{client: client, config: &config.MjConfig}
 | 
			
		||||
type ImageReq struct {
 | 
			
		||||
	BotType       string      `json:"botType,omitempty"`
 | 
			
		||||
	Prompt        string      `json:"prompt,omitempty"`
 | 
			
		||||
	Dimensions    string      `json:"dimensions,omitempty"`
 | 
			
		||||
	Base64Array   []string    `json:"base64Array,omitempty"`
 | 
			
		||||
	AccountFilter interface{} `json:"accountFilter,omitempty"`
 | 
			
		||||
	NotifyHook    string      `json:"notifyHook,omitempty"`
 | 
			
		||||
	State         string      `json:"state,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *Client) Imagine(prompt string) error {
 | 
			
		||||
	interactionsReq := &InteractionsRequest{
 | 
			
		||||
		Type:          2,
 | 
			
		||||
		ApplicationID: ApplicationID,
 | 
			
		||||
		GuildID:       c.config.GuildId,
 | 
			
		||||
		ChannelID:     c.config.ChanelId,
 | 
			
		||||
		SessionID:     SessionID,
 | 
			
		||||
		Data: map[string]any{
 | 
			
		||||
			"version": "1166847114203123795",
 | 
			
		||||
			"id":      "938956540159881230",
 | 
			
		||||
			"name":    "imagine",
 | 
			
		||||
			"type":    "1",
 | 
			
		||||
			"options": []map[string]any{
 | 
			
		||||
				{
 | 
			
		||||
					"type":  3,
 | 
			
		||||
					"name":  "prompt",
 | 
			
		||||
					"value": prompt,
 | 
			
		||||
				},
 | 
			
		||||
			},
 | 
			
		||||
			"application_command": map[string]any{
 | 
			
		||||
				"id":                         "938956540159881230",
 | 
			
		||||
				"application_id":             ApplicationID,
 | 
			
		||||
				"version":                    "1118961510123847772",
 | 
			
		||||
				"default_permission":         true,
 | 
			
		||||
				"default_member_permissions": nil,
 | 
			
		||||
				"type":                       1,
 | 
			
		||||
				"nsfw":                       false,
 | 
			
		||||
				"name":                       "imagine",
 | 
			
		||||
				"description":                "Create images with Midjourney",
 | 
			
		||||
				"dm_permission":              true,
 | 
			
		||||
				"options": []map[string]any{
 | 
			
		||||
					{
 | 
			
		||||
						"type":        3,
 | 
			
		||||
						"name":        "prompt",
 | 
			
		||||
						"description": "The prompt to imagine",
 | 
			
		||||
						"required":    true,
 | 
			
		||||
					},
 | 
			
		||||
				},
 | 
			
		||||
				"attachments": []any{},
 | 
			
		||||
			},
 | 
			
		||||
		},
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := "https://discord.com/api/v9/interactions"
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(interactionsReq).
 | 
			
		||||
		Post(url)
 | 
			
		||||
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error with http request: %w%v", err, r.Err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
type ImageRes struct {
 | 
			
		||||
	Code        int    `json:"code"`
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
	Result string `json:"result"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *Client) Upscale(index int, messageId string, hash string) error {
 | 
			
		||||
	flags := 0
 | 
			
		||||
	interactionsReq := &InteractionsRequest{
 | 
			
		||||
		Type:          3,
 | 
			
		||||
		ApplicationID: ApplicationID,
 | 
			
		||||
		GuildID:       c.config.GuildId,
 | 
			
		||||
		ChannelID:     c.config.ChanelId,
 | 
			
		||||
		MessageFlags:  &flags,
 | 
			
		||||
		MessageID:     &messageId,
 | 
			
		||||
		SessionID:     SessionID,
 | 
			
		||||
		Data: map[string]any{
 | 
			
		||||
			"component_type": 2,
 | 
			
		||||
			"custom_id":      fmt.Sprintf("MJ::JOB::upsample::%d::%s", index, hash),
 | 
			
		||||
		},
 | 
			
		||||
		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := "https://discord.com/api/v9/interactions"
 | 
			
		||||
	var res InteractionsResult
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(interactionsReq).
 | 
			
		||||
		SetErrorResult(&res).
 | 
			
		||||
		Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
type ErrRes struct {
 | 
			
		||||
	Error struct {
 | 
			
		||||
		Message string `json:"message"`
 | 
			
		||||
	} `json:"error"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *Client) Variation(index int, messageId string, hash string) error {
 | 
			
		||||
	flags := 0
 | 
			
		||||
	interactionsReq := &InteractionsRequest{
 | 
			
		||||
		Type:          3,
 | 
			
		||||
		ApplicationID: ApplicationID,
 | 
			
		||||
		GuildID:       c.config.GuildId,
 | 
			
		||||
		ChannelID:     c.config.ChanelId,
 | 
			
		||||
		MessageFlags:  &flags,
 | 
			
		||||
		MessageID:     &messageId,
 | 
			
		||||
		SessionID:     SessionID,
 | 
			
		||||
		Data: map[string]any{
 | 
			
		||||
			"component_type": 2,
 | 
			
		||||
			"custom_id":      fmt.Sprintf("MJ::JOB::variation::%d::%s", index, hash),
 | 
			
		||||
		},
 | 
			
		||||
		Nonce: fmt.Sprintf("%d", time.Now().UnixNano()),
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	url := "https://discord.com/api/v9/interactions"
 | 
			
		||||
	var res InteractionsResult
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", c.config.UserToken).
 | 
			
		||||
		SetHeader("Content-Type", "application/json").
 | 
			
		||||
		SetBody(interactionsReq).
 | 
			
		||||
		SetErrorResult(&res).
 | 
			
		||||
		Post(url)
 | 
			
		||||
	if err != nil || r.IsErrorState() {
 | 
			
		||||
		return fmt.Errorf("error with http request: %v%v%v", err, r.Err, res.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return nil
 | 
			
		||||
type QueryRes struct {
 | 
			
		||||
	Action  string `json:"action"`
 | 
			
		||||
	Buttons []struct {
 | 
			
		||||
		CustomId string `json:"customId"`
 | 
			
		||||
		Emoji    string `json:"emoji"`
 | 
			
		||||
		Label    string `json:"label"`
 | 
			
		||||
		Style    int    `json:"style"`
 | 
			
		||||
		Type     int    `json:"type"`
 | 
			
		||||
	} `json:"buttons"`
 | 
			
		||||
	Description string `json:"description"`
 | 
			
		||||
	FailReason  string `json:"failReason"`
 | 
			
		||||
	FinishTime  int    `json:"finishTime"`
 | 
			
		||||
	Id          string `json:"id"`
 | 
			
		||||
	ImageUrl    string `json:"imageUrl"`
 | 
			
		||||
	Progress    string `json:"progress"`
 | 
			
		||||
	Prompt      string `json:"prompt"`
 | 
			
		||||
	PromptEn    string `json:"promptEn"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
	StartTime  int    `json:"startTime"`
 | 
			
		||||
	State      string `json:"state"`
 | 
			
		||||
	Status     string `json:"status"`
 | 
			
		||||
	SubmitTime int    `json:"submitTime"`
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										240
									
								
								api/service/mj/plus_client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										240
									
								
								api/service/mj/plus_client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,240 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"io"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// PlusClient MidJourney Plus ProxyClient
 | 
			
		||||
type PlusClient struct {
 | 
			
		||||
	Config types.MjPlusConfig
 | 
			
		||||
	apiURL string
 | 
			
		||||
	client *req.Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPlusClient(config types.MjPlusConfig) *PlusClient {
 | 
			
		||||
	return &PlusClient{
 | 
			
		||||
		Config: config,
 | 
			
		||||
		apiURL: config.ApiURL,
 | 
			
		||||
		client: req.C().SetTimeout(time.Minute).SetUserAgent("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36"),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *PlusClient) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/imagine", c.apiURL, c.Config.Mode)
 | 
			
		||||
	prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
 | 
			
		||||
	if task.NegPrompt != "" {
 | 
			
		||||
		prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
 | 
			
		||||
	}
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		BotType:     "MID_JOURNEY",
 | 
			
		||||
		Prompt:      prompt,
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with download image: ", err)
 | 
			
		||||
		} else {
 | 
			
		||||
			body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := c.client.R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		errStr, _ := io.ReadAll(r.Body)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blend 融图
 | 
			
		||||
func (c *PlusClient) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/blend", c.apiURL, c.Config.Mode)
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		BotType:     "MID_JOURNEY",
 | 
			
		||||
		Dimensions:  "SQUARE",
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		for _, imgURL := range task.ImgArr {
 | 
			
		||||
			imageData, err := utils.DownloadImage(imgURL, "")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
			} else {
 | 
			
		||||
				body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := c.client.R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SwapFace 换脸
 | 
			
		||||
func (c *PlusClient) SwapFace(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/insight-face/swap", c.apiURL, c.Config.Mode)
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) != 2 {
 | 
			
		||||
		return ImageRes{}, errors.New("参数错误,必须上传2张图片")
 | 
			
		||||
	}
 | 
			
		||||
	var sourceBase64 string
 | 
			
		||||
	var targetBase64 string
 | 
			
		||||
	imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("error with download image: ", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		sourceBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 | 
			
		||||
	}
 | 
			
		||||
	imageData, err = utils.DownloadImage(task.ImgArr[1], "")
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		logger.Error("error with download image: ", err)
 | 
			
		||||
	} else {
 | 
			
		||||
		targetBase64 = "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	body := gin.H{
 | 
			
		||||
		"sourceBase64": sourceBase64,
 | 
			
		||||
		"targetBase64": targetBase64,
 | 
			
		||||
		"accountFilter": gin.H{
 | 
			
		||||
			"instanceId": "",
 | 
			
		||||
		},
 | 
			
		||||
		"state": "",
 | 
			
		||||
	}
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := c.client.SetTimeout(time.Minute).R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *PlusClient) Upscale(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]string{
 | 
			
		||||
		"customId": fmt.Sprintf("MJ::JOB::upsample::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":   task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := c.client.R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *PlusClient) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]string{
 | 
			
		||||
		"customId": fmt.Sprintf("MJ::JOB::variation::%d::%s", task.Index, task.MessageHash),
 | 
			
		||||
		"taskId":   task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj-%s/mj/submit/action", c.apiURL, c.Config.Mode)
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *PlusClient) QueryTask(taskId string) (QueryRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
 | 
			
		||||
	var res QueryRes
 | 
			
		||||
	r, err := c.client.R().SetHeader("Authorization", "Bearer "+c.Config.ApiKey).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Get(apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return QueryRes{}, errors.New("error status:" + r.Status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Client = &PlusClient{}
 | 
			
		||||
							
								
								
									
										227
									
								
								api/service/mj/pool.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										227
									
								
								api/service/mj/pool.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,227 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"geekai/service/oss"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ServicePool Mj service pool
 | 
			
		||||
type ServicePool struct {
 | 
			
		||||
	services        []*Service
 | 
			
		||||
	taskQueue       *store.RedisQueue
 | 
			
		||||
	notifyQueue     *store.RedisQueue
 | 
			
		||||
	db              *gorm.DB
 | 
			
		||||
	uploaderManager *oss.UploaderManager
 | 
			
		||||
	Clients         *types.LMap[uint, *types.WsClient] // UserId => Client
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var logger = logger2.GetLogger()
 | 
			
		||||
 | 
			
		||||
func NewServicePool(db *gorm.DB, redisCli *redis.Client, manager *oss.UploaderManager) *ServicePool {
 | 
			
		||||
	services := make([]*Service, 0)
 | 
			
		||||
	taskQueue := store.NewRedisQueue("MidJourney_Task_Queue", redisCli)
 | 
			
		||||
	notifyQueue := store.NewRedisQueue("MidJourney_Notify_Queue", redisCli)
 | 
			
		||||
	return &ServicePool{
 | 
			
		||||
		taskQueue:       taskQueue,
 | 
			
		||||
		notifyQueue:     notifyQueue,
 | 
			
		||||
		services:        services,
 | 
			
		||||
		uploaderManager: manager,
 | 
			
		||||
		db:              db,
 | 
			
		||||
		Clients:         types.NewLMap[uint, *types.WsClient](),
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) InitServices(plusConfigs []types.MjPlusConfig, proxyConfigs []types.MjProxyConfig) {
 | 
			
		||||
	// stop old service
 | 
			
		||||
	for _, s := range p.services {
 | 
			
		||||
		s.Stop()
 | 
			
		||||
	}
 | 
			
		||||
	p.services = make([]*Service, 0)
 | 
			
		||||
 | 
			
		||||
	for k, config := range plusConfigs {
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		cli := NewPlusClient(config)
 | 
			
		||||
		name := fmt.Sprintf("mj-plus-service-%d", k)
 | 
			
		||||
		plusService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
 | 
			
		||||
		go func() {
 | 
			
		||||
			plusService.Run()
 | 
			
		||||
		}()
 | 
			
		||||
		p.services = append(p.services, plusService)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// for mid-journey proxy
 | 
			
		||||
	for k, config := range proxyConfigs {
 | 
			
		||||
		if config.Enabled == false {
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		cli := NewProxyClient(config)
 | 
			
		||||
		name := fmt.Sprintf("mj-proxy-service-%d", k)
 | 
			
		||||
		proxyService := NewService(name, p.taskQueue, p.notifyQueue, p.db, cli)
 | 
			
		||||
		go func() {
 | 
			
		||||
			proxyService.Run()
 | 
			
		||||
		}()
 | 
			
		||||
		p.services = append(p.services, proxyService)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) CheckTaskNotify() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		for {
 | 
			
		||||
			var message sd.NotifyMessage
 | 
			
		||||
			err := p.notifyQueue.LPop(&message)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			cli := p.Clients.Get(uint(message.UserId))
 | 
			
		||||
			if cli == nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
			err = cli.Send([]byte(message.Message))
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) DownloadImages() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.MidJourneyJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := p.db.Where("img_url = ? AND progress = ?", "", 100).Find(&items)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// download images
 | 
			
		||||
			for _, v := range items {
 | 
			
		||||
				if v.OrgURL == "" {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				logger.Infof("try to download image: %s", v.OrgURL)
 | 
			
		||||
				mjService := p.getService(v.ChannelId)
 | 
			
		||||
				if mjService == nil {
 | 
			
		||||
					logger.Errorf("Invalid task: %+v", v)
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				task, _ := mjService.Client.QueryTask(v.TaskId)
 | 
			
		||||
				if len(task.Buttons) > 0 {
 | 
			
		||||
					v.Hash = GetImageHash(task.Buttons[0].CustomId)
 | 
			
		||||
				}
 | 
			
		||||
				// 如果是返回的是 discord 图片地址,则使用代理下载
 | 
			
		||||
				proxy := false
 | 
			
		||||
				if strings.HasPrefix(v.OrgURL, "https://cdn.discordapp.com") {
 | 
			
		||||
					proxy = true
 | 
			
		||||
				}
 | 
			
		||||
				imgURL, err := p.uploaderManager.GetUploadHandler().PutImg(v.OrgURL, proxy)
 | 
			
		||||
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					logger.Errorf("error with download image %s, %v", v.OrgURL, err)
 | 
			
		||||
					continue
 | 
			
		||||
				} else {
 | 
			
		||||
					logger.Infof("download image %s successfully.", v.OrgURL)
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				v.ImgURL = imgURL
 | 
			
		||||
				p.db.Updates(&v)
 | 
			
		||||
 | 
			
		||||
				cli := p.Clients.Get(uint(v.UserId))
 | 
			
		||||
				if cli == nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
				err = cli.Send([]byte(sd.Finished))
 | 
			
		||||
				if err != nil {
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 5)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// PushTask push a new mj task in to task queue
 | 
			
		||||
func (p *ServicePool) PushTask(task types.MjTask) {
 | 
			
		||||
	logger.Debugf("add a new MidJourney task to the task list: %+v", task)
 | 
			
		||||
	p.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// HasAvailableService check if it has available mj service in pool
 | 
			
		||||
func (p *ServicePool) HasAvailableService() bool {
 | 
			
		||||
	return len(p.services) > 0
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SyncTaskProgress 异步拉取任务
 | 
			
		||||
func (p *ServicePool) SyncTaskProgress() {
 | 
			
		||||
	go func() {
 | 
			
		||||
		var items []model.MidJourneyJob
 | 
			
		||||
		for {
 | 
			
		||||
			res := p.db.Where("progress < ?", 100).Find(&items)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				continue
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			for _, job := range items {
 | 
			
		||||
				// 失败或者 30 分钟还没完成的任务删除并退回算力
 | 
			
		||||
				if time.Now().Sub(job.CreatedAt) > time.Minute*30 || job.Progress == -1 {
 | 
			
		||||
					p.db.Delete(&job)
 | 
			
		||||
					// 退回算力
 | 
			
		||||
					tx := p.db.Model(&model.User{}).Where("id = ?", job.UserId).UpdateColumn("power", gorm.Expr("power + ?", job.Power))
 | 
			
		||||
					if tx.Error == nil && tx.RowsAffected > 0 {
 | 
			
		||||
						var user model.User
 | 
			
		||||
						p.db.Where("id = ?", job.UserId).First(&user)
 | 
			
		||||
						p.db.Create(&model.PowerLog{
 | 
			
		||||
							UserId:    user.Id,
 | 
			
		||||
							Username:  user.Username,
 | 
			
		||||
							Type:      types.PowerConsume,
 | 
			
		||||
							Amount:    job.Power,
 | 
			
		||||
							Balance:   user.Power + job.Power,
 | 
			
		||||
							Mark:      types.PowerAdd,
 | 
			
		||||
							Model:     "mid-journey",
 | 
			
		||||
							Remark:    fmt.Sprintf("绘画任务失败,退回算力。任务ID:%s", job.TaskId),
 | 
			
		||||
							CreatedAt: time.Now(),
 | 
			
		||||
						})
 | 
			
		||||
					}
 | 
			
		||||
					continue
 | 
			
		||||
				}
 | 
			
		||||
 | 
			
		||||
				if servicePlus := p.getService(job.ChannelId); servicePlus != nil {
 | 
			
		||||
					_ = servicePlus.Notify(job)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			time.Sleep(time.Second * 10)
 | 
			
		||||
		}
 | 
			
		||||
	}()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (p *ServicePool) getService(name string) *Service {
 | 
			
		||||
	for _, s := range p.services {
 | 
			
		||||
		if s.Name == name {
 | 
			
		||||
			return s
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										185
									
								
								api/service/mj/proxy_client.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										185
									
								
								api/service/mj/proxy_client.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,185 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/imroc/req/v3"
 | 
			
		||||
	"io"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// ProxyClient MidJourney Proxy Client
 | 
			
		||||
type ProxyClient struct {
 | 
			
		||||
	Config types.MjProxyConfig
 | 
			
		||||
	apiURL string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewProxyClient(config types.MjProxyConfig) *ProxyClient {
 | 
			
		||||
	return &ProxyClient{Config: config, apiURL: config.ApiURL}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ProxyClient) Imagine(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/imagine", c.apiURL)
 | 
			
		||||
	prompt := fmt.Sprintf("%s %s", task.Prompt, task.Params)
 | 
			
		||||
	if task.NegPrompt != "" {
 | 
			
		||||
		prompt += fmt.Sprintf(" --no %s", task.NegPrompt)
 | 
			
		||||
	}
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		Prompt:      prompt,
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		imageData, err := utils.DownloadImage(task.ImgArr[0], "")
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("error with download image: ", err)
 | 
			
		||||
		} else {
 | 
			
		||||
			body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	}
 | 
			
		||||
	logger.Info("API URL: ", apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		errStr, _ := io.ReadAll(r.Body)
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s,%v", errRes.Error.Message, string(errStr))
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Blend 融图
 | 
			
		||||
func (c *ProxyClient) Blend(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/blend", c.apiURL)
 | 
			
		||||
	body := ImageReq{
 | 
			
		||||
		Dimensions:  "SQUARE",
 | 
			
		||||
		Base64Array: make([]string, 0),
 | 
			
		||||
	}
 | 
			
		||||
	// 生成图片 Base64 编码
 | 
			
		||||
	if len(task.ImgArr) > 0 {
 | 
			
		||||
		for _, imgURL := range task.ImgArr {
 | 
			
		||||
			imageData, err := utils.DownloadImage(imgURL, "")
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
			} else {
 | 
			
		||||
				body.Base64Array = append(body.Base64Array, "data:image/png;base64,"+base64.StdEncoding.EncodeToString(imageData))
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API %s 出错:%v", apiURL, err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// SwapFace 换脸
 | 
			
		||||
func (c *ProxyClient) SwapFace(_ types.MjTask) (ImageRes, error) {
 | 
			
		||||
	return ImageRes{}, errors.New("MidJourney-Proxy暂未实现该功能,请使用 MidJourney-Plus")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Upscale 放大指定的图片
 | 
			
		||||
func (c *ProxyClient) Upscale(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]interface{}{
 | 
			
		||||
		"action": "UPSCALE",
 | 
			
		||||
		"index":  task.Index,
 | 
			
		||||
		"taskId": task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Variation  以指定的图片的视角进行变换再创作,注意需要在对应的频道中关闭 Remix 变换,否则 Variation 指令将不会生效
 | 
			
		||||
func (c *ProxyClient) Variation(task types.MjTask) (ImageRes, error) {
 | 
			
		||||
	body := map[string]interface{}{
 | 
			
		||||
		"action": "VARIATION",
 | 
			
		||||
		"index":  task.Index,
 | 
			
		||||
		"taskId": task.MessageId,
 | 
			
		||||
	}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/submit/change", c.apiURL)
 | 
			
		||||
	var res ImageRes
 | 
			
		||||
	var errRes ErrRes
 | 
			
		||||
	r, err := req.C().R().
 | 
			
		||||
		SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetBody(body).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		SetErrorResult(&errRes).
 | 
			
		||||
		Post(apiURL)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("请求 API 出错:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return ImageRes{}, fmt.Errorf("API 返回错误:%s", errRes.Error.Message)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (c *ProxyClient) QueryTask(taskId string) (QueryRes, error) {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/mj/task/%s/fetch", c.apiURL, taskId)
 | 
			
		||||
	var res QueryRes
 | 
			
		||||
	r, err := req.C().R().SetHeader("mj-api-secret", c.Config.ApiKey).
 | 
			
		||||
		SetSuccessResult(&res).
 | 
			
		||||
		Get(apiURL)
 | 
			
		||||
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return QueryRes{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.IsErrorState() {
 | 
			
		||||
		return QueryRes{}, errors.New("error status:" + r.Status)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Client = &ProxyClient{}
 | 
			
		||||
@@ -1,249 +1,198 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/service/oss"
 | 
			
		||||
	"chatplus/store"
 | 
			
		||||
	"chatplus/store/model"
 | 
			
		||||
	"chatplus/store/vo"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/go-redis/redis/v8"
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/service"
 | 
			
		||||
	"geekai/service/sd"
 | 
			
		||||
	"geekai/store"
 | 
			
		||||
	"geekai/store/model"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"gorm.io/gorm"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
// MJ 绘画服务
 | 
			
		||||
 | 
			
		||||
const RunningJobKey = "MidJourney_Running_Job"
 | 
			
		||||
 | 
			
		||||
// Service MJ 绘画服务
 | 
			
		||||
type Service struct {
 | 
			
		||||
	client        *Client // MJ 客户端
 | 
			
		||||
	taskQueue     *store.RedisQueue
 | 
			
		||||
	redis         *redis.Client
 | 
			
		||||
	db            *gorm.DB
 | 
			
		||||
	uploadManager *oss.UploaderManager
 | 
			
		||||
	Clients       *types.LMap[string, *types.WsClient] // MJ 绘画页面 websocket 连接池,用户推送绘画消息
 | 
			
		||||
	ChatClients   *types.LMap[string, *types.WsClient] // 聊天页面 websocket 连接池,用于推送绘画消息
 | 
			
		||||
	proxyURL      string
 | 
			
		||||
	Name        string // service Name
 | 
			
		||||
	Client      Client // MJ Client
 | 
			
		||||
	taskQueue   *store.RedisQueue
 | 
			
		||||
	notifyQueue *store.RedisQueue
 | 
			
		||||
	db          *gorm.DB
 | 
			
		||||
	running     bool
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewService(redisCli *redis.Client, db *gorm.DB, client *Client, manager *oss.UploaderManager, config *types.AppConfig) *Service {
 | 
			
		||||
func NewService(name string, taskQueue *store.RedisQueue, notifyQueue *store.RedisQueue, db *gorm.DB, cli Client) *Service {
 | 
			
		||||
	return &Service{
 | 
			
		||||
		redis:         redisCli,
 | 
			
		||||
		db:            db,
 | 
			
		||||
		taskQueue:     store.NewRedisQueue("MidJourney_Task_Queue", redisCli),
 | 
			
		||||
		client:        client,
 | 
			
		||||
		uploadManager: manager,
 | 
			
		||||
		Clients:       types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		ChatClients:   types.NewLMap[string, *types.WsClient](),
 | 
			
		||||
		proxyURL:      config.ProxyURL,
 | 
			
		||||
		Name:        name,
 | 
			
		||||
		db:          db,
 | 
			
		||||
		taskQueue:   taskQueue,
 | 
			
		||||
		notifyQueue: notifyQueue,
 | 
			
		||||
		Client:      cli,
 | 
			
		||||
		running:     true,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Run() {
 | 
			
		||||
	logger.Info("Starting MidJourney job consumer.")
 | 
			
		||||
	ctx := context.Background()
 | 
			
		||||
	for {
 | 
			
		||||
		_, err := s.redis.Get(ctx, RunningJobKey).Result()
 | 
			
		||||
		if err == nil { // 队列串行执行
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
	logger.Infof("Starting MidJourney job consumer for %s", s.Name)
 | 
			
		||||
	for s.running {
 | 
			
		||||
		var task types.MjTask
 | 
			
		||||
		err = s.taskQueue.LPop(&task)
 | 
			
		||||
		err := s.taskQueue.LPop(&task)
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Errorf("taking task with error: %v", err)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("Consuming Task: %+v", task)
 | 
			
		||||
		switch task.Type {
 | 
			
		||||
		case types.TaskImage:
 | 
			
		||||
			err = s.client.Imagine(task.Prompt)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskUpscale:
 | 
			
		||||
			err = s.client.Upscale(task.Index, task.MessageId, task.MessageHash)
 | 
			
		||||
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskVariation:
 | 
			
		||||
			err = s.client.Variation(task.Index, task.MessageId, task.MessageHash)
 | 
			
		||||
		}
 | 
			
		||||
		if err != nil {
 | 
			
		||||
			logger.Error("绘画任务执行失败:", err)
 | 
			
		||||
			if task.RetryCount <= 5 {
 | 
			
		||||
				s.taskQueue.RPush(task)
 | 
			
		||||
			}
 | 
			
		||||
			task.RetryCount += 1
 | 
			
		||||
			time.Sleep(time.Second * 3)
 | 
			
		||||
		//  如果配置了多个中转平台的 API KEY
 | 
			
		||||
		// U,V 操作必须和 Image 操作属于同一个平台,否则找不到关联任务,需重新放回任务列表
 | 
			
		||||
		if task.ChannelId != "" && task.ChannelId != s.Name {
 | 
			
		||||
			logger.Debugf("handle other service task, name: %s, channel_id: %s, drop it.", s.Name, task.ChannelId)
 | 
			
		||||
			s.taskQueue.RPush(task)
 | 
			
		||||
			time.Sleep(time.Second)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// 更新任务的执行状态
 | 
			
		||||
		s.db.Model(&model.MidJourneyJob{}).Where("id = ?", task.Id).UpdateColumn("started", true)
 | 
			
		||||
		// 锁定任务执行通道,直到任务超时(5分钟)
 | 
			
		||||
		s.redis.Set(ctx, RunningJobKey, utils.JsonEncode(task), time.Minute*5)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
		// translate prompt
 | 
			
		||||
		if utils.HasChinese(task.Prompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.Prompt))
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.Prompt = content
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
		// translate negative prompt
 | 
			
		||||
		if task.NegPrompt != "" && utils.HasChinese(task.NegPrompt) {
 | 
			
		||||
			content, err := utils.OpenAIRequest(s.db, fmt.Sprintf(service.RewritePromptTemplate, task.NegPrompt))
 | 
			
		||||
			if err == nil {
 | 
			
		||||
				task.NegPrompt = content
 | 
			
		||||
			} else {
 | 
			
		||||
				logger.Warnf("error with translate prompt: %v", err)
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
func (s *Service) PushTask(task types.MjTask) {
 | 
			
		||||
	logger.Infof("add a new MidJourney Task: %+v", task)
 | 
			
		||||
	s.taskQueue.RPush(task)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Notify(data CBReq) {
 | 
			
		||||
	taskString, err := s.redis.Get(context.Background(), RunningJobKey).Result()
 | 
			
		||||
	if err != nil { // 过期任务,丢弃
 | 
			
		||||
		logger.Warn("任务已过期:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var task types.MjTask
 | 
			
		||||
	err = utils.JsonDecode(taskString, &task)
 | 
			
		||||
	if err != nil { // 非标准任务,丢弃
 | 
			
		||||
		logger.Warn("任务解析失败:", err)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var job model.MidJourneyJob
 | 
			
		||||
	res := s.db.Where("message_id = ?", data.MessageId).First(&job)
 | 
			
		||||
	if res.Error == nil && data.Status == Finished {
 | 
			
		||||
		logger.Warn("重复消息:", data.MessageId)
 | 
			
		||||
		return
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if task.Src == types.TaskSrcImg { // 绘画任务
 | 
			
		||||
		var job model.MidJourneyJob
 | 
			
		||||
		res := s.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Warn("非法任务:", res.Error)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
		job.MessageId = data.MessageId
 | 
			
		||||
		job.ReferenceId = data.ReferenceId
 | 
			
		||||
		job.Progress = data.Progress
 | 
			
		||||
		job.Prompt = data.Prompt
 | 
			
		||||
		job.Hash = data.Image.Hash
 | 
			
		||||
 | 
			
		||||
		// 任务完成,将最终的图片下载下来
 | 
			
		||||
		if data.Progress == 100 {
 | 
			
		||||
			imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download img: ", err.Error())
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			job.ImgURL = imgURL
 | 
			
		||||
		} else {
 | 
			
		||||
			// 临时图片直接保存,访问的时候使用代理进行转发
 | 
			
		||||
			job.ImgURL = data.Image.URL
 | 
			
		||||
		}
 | 
			
		||||
		res = s.db.Updates(&job)
 | 
			
		||||
		if res.Error != nil {
 | 
			
		||||
			logger.Error("error with update job: ", res.Error)
 | 
			
		||||
			return
 | 
			
		||||
		tx := s.db.Where("id = ?", task.Id).First(&job)
 | 
			
		||||
		if tx.Error != nil {
 | 
			
		||||
			logger.Error("任务不存在,任务ID:", task.TaskId)
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		var jobVo vo.MidJourneyJob
 | 
			
		||||
		err := utils.CopyObject(job, &jobVo)
 | 
			
		||||
		if err == nil {
 | 
			
		||||
			if data.Progress < 100 {
 | 
			
		||||
				image, err := utils.DownloadImage(jobVo.ImgURL, s.proxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					jobVo.ImgURL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// 推送任务到前端
 | 
			
		||||
			client := s.Clients.Get(task.SessionId)
 | 
			
		||||
			if client != nil {
 | 
			
		||||
				utils.ReplyChunkMessage(client, jobVo)
 | 
			
		||||
			}
 | 
			
		||||
		logger.Infof("%s handle a new MidJourney task: %+v", s.Name, task)
 | 
			
		||||
		var res ImageRes
 | 
			
		||||
		switch task.Type {
 | 
			
		||||
		case types.TaskImage:
 | 
			
		||||
			res, err = s.Client.Imagine(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskUpscale:
 | 
			
		||||
			res, err = s.Client.Upscale(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskVariation:
 | 
			
		||||
			res, err = s.Client.Variation(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskBlend:
 | 
			
		||||
			res, err = s.Client.Blend(task)
 | 
			
		||||
			break
 | 
			
		||||
		case types.TaskSwapFace:
 | 
			
		||||
			res, err = s.Client.SwapFace(task)
 | 
			
		||||
			break
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
	} else if task.Src == types.TaskSrcChat { // 聊天任务
 | 
			
		||||
		wsClient := s.ChatClients.Get(task.SessionId)
 | 
			
		||||
		if data.Status == Finished {
 | 
			
		||||
			if wsClient != nil && data.ReferenceId != "" {
 | 
			
		||||
				content := fmt.Sprintf("**%s** 任务执行成功,正在从 MidJourney 服务器下载图片,请稍后...", data.Prompt)
 | 
			
		||||
				utils.ReplyMessage(wsClient, content)
 | 
			
		||||
			}
 | 
			
		||||
			// download image
 | 
			
		||||
			imgURL, err := s.uploadManager.GetUploadHandler().PutImg(data.Image.URL, true)
 | 
			
		||||
			if err != nil {
 | 
			
		||||
				logger.Error("error with download image: ", err)
 | 
			
		||||
				if wsClient != nil && data.ReferenceId != "" {
 | 
			
		||||
					content := fmt.Sprintf("**%s** 图片下载失败:%s", data.Prompt, err.Error())
 | 
			
		||||
					utils.ReplyMessage(wsClient, content)
 | 
			
		||||
				}
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			tx := s.db.Begin()
 | 
			
		||||
			data.Image.URL = imgURL
 | 
			
		||||
			message := model.HistoryMessage{
 | 
			
		||||
				UserId:     uint(task.UserId),
 | 
			
		||||
				ChatId:     task.ChatId,
 | 
			
		||||
				RoleId:     uint(task.RoleId),
 | 
			
		||||
				Type:       types.MjMsg,
 | 
			
		||||
				Icon:       task.Icon,
 | 
			
		||||
				Content:    utils.JsonEncode(data),
 | 
			
		||||
				Tokens:     0,
 | 
			
		||||
				UseContext: false,
 | 
			
		||||
			}
 | 
			
		||||
			res = tx.Create(&message)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("error with update database: ", err)
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// save the job
 | 
			
		||||
			job.UserId = task.UserId
 | 
			
		||||
			job.Type = task.Type.String()
 | 
			
		||||
			job.MessageId = data.MessageId
 | 
			
		||||
			job.ReferenceId = data.ReferenceId
 | 
			
		||||
			job.Prompt = data.Prompt
 | 
			
		||||
			job.ImgURL = imgURL
 | 
			
		||||
			job.Progress = data.Progress
 | 
			
		||||
			job.Hash = data.Image.Hash
 | 
			
		||||
			job.CreatedAt = time.Now()
 | 
			
		||||
			res = tx.Create(&job)
 | 
			
		||||
			if res.Error != nil {
 | 
			
		||||
				logger.Error("error with update database: ", err)
 | 
			
		||||
				tx.Rollback()
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			tx.Commit()
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if wsClient == nil { // 客户端断线,则丢弃
 | 
			
		||||
			logger.Errorf("Client is offline: %+v", data)
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		if data.Status == Finished {
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsEnd})
 | 
			
		||||
			// 本次绘画完毕,移除客户端
 | 
			
		||||
			s.ChatClients.Delete(task.SessionId)
 | 
			
		||||
		} else {
 | 
			
		||||
			// 使用代理临时转发图片
 | 
			
		||||
			if data.Image.URL != "" {
 | 
			
		||||
				image, err := utils.DownloadImage(data.Image.URL, s.proxyURL)
 | 
			
		||||
				if err == nil {
 | 
			
		||||
					data.Image.URL = "data:image/png;base64," + base64.StdEncoding.EncodeToString(image)
 | 
			
		||||
				}
 | 
			
		||||
			}
 | 
			
		||||
			utils.ReplyChunkMessage(wsClient, types.WsMessage{Type: types.WsMjImg, Content: data})
 | 
			
		||||
		if err != nil || (res.Code != 1 && res.Code != 22) {
 | 
			
		||||
			errMsg := fmt.Sprintf("%v,%s", err, res.Description)
 | 
			
		||||
			logger.Error("绘画任务执行失败:", errMsg)
 | 
			
		||||
			job.Progress = -1
 | 
			
		||||
			job.ErrMsg = errMsg
 | 
			
		||||
			// update the task progress
 | 
			
		||||
			s.db.Updates(&job)
 | 
			
		||||
			// 任务失败,通知前端
 | 
			
		||||
			s.notifyQueue.RPush(sd.NotifyMessage{UserId: task.UserId, JobId: int(job.Id), Message: sd.Failed})
 | 
			
		||||
			continue
 | 
			
		||||
		}
 | 
			
		||||
		logger.Infof("任务提交成功:%+v", res)
 | 
			
		||||
		// 更新任务 ID/频道
 | 
			
		||||
		job.TaskId = res.Result
 | 
			
		||||
		job.MessageId = res.Result
 | 
			
		||||
		job.ChannelId = s.Name
 | 
			
		||||
		s.db.Updates(&job)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 更新用户剩余绘图次数
 | 
			
		||||
	// TODO: 放大图片是否需要消耗绘图次数?
 | 
			
		||||
	if data.Status == Finished {
 | 
			
		||||
		s.db.Model(&model.User{}).Where("id = ?", task.UserId).UpdateColumn("img_calls", gorm.Expr("img_calls - ?", 1))
 | 
			
		||||
		// 解除任务锁定
 | 
			
		||||
		s.redis.Del(context.Background(), RunningJobKey)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Stop() {
 | 
			
		||||
	s.running = false
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
	Id          string      `json:"id"`
 | 
			
		||||
	Action      string      `json:"action"`
 | 
			
		||||
	Status      string      `json:"status"`
 | 
			
		||||
	Prompt      string      `json:"prompt"`
 | 
			
		||||
	PromptEn    string      `json:"promptEn"`
 | 
			
		||||
	Description string      `json:"description"`
 | 
			
		||||
	SubmitTime  int64       `json:"submitTime"`
 | 
			
		||||
	StartTime   int64       `json:"startTime"`
 | 
			
		||||
	FinishTime  int64       `json:"finishTime"`
 | 
			
		||||
	Progress    string      `json:"progress"`
 | 
			
		||||
	ImageUrl    string      `json:"imageUrl"`
 | 
			
		||||
	FailReason  interface{} `json:"failReason"`
 | 
			
		||||
	Properties  struct {
 | 
			
		||||
		FinalPrompt string `json:"finalPrompt"`
 | 
			
		||||
	} `json:"properties"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s *Service) Notify(job model.MidJourneyJob) error {
 | 
			
		||||
	task, err := s.Client.QueryTask(job.TaskId)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// 任务执行失败了
 | 
			
		||||
	if task.FailReason != "" {
 | 
			
		||||
		s.db.Model(&model.MidJourneyJob{Id: job.Id}).UpdateColumns(map[string]interface{}{
 | 
			
		||||
			"progress": -1,
 | 
			
		||||
			"err_msg":  task.FailReason,
 | 
			
		||||
		})
 | 
			
		||||
		s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: sd.Failed})
 | 
			
		||||
		return fmt.Errorf("task failed: %v", task.FailReason)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if len(task.Buttons) > 0 {
 | 
			
		||||
		job.Hash = GetImageHash(task.Buttons[0].CustomId)
 | 
			
		||||
	}
 | 
			
		||||
	oldProgress := job.Progress
 | 
			
		||||
	job.Progress = utils.IntValue(strings.Replace(task.Progress, "%", "", 1), 0)
 | 
			
		||||
	job.Prompt = task.PromptEn
 | 
			
		||||
	if task.ImageUrl != "" {
 | 
			
		||||
		job.OrgURL = task.ImageUrl
 | 
			
		||||
	}
 | 
			
		||||
	tx := s.db.Updates(&job)
 | 
			
		||||
	if tx.Error != nil {
 | 
			
		||||
		return fmt.Errorf("error with update database: %v", tx.Error)
 | 
			
		||||
	}
 | 
			
		||||
	// 通知前端更新任务进度
 | 
			
		||||
	if oldProgress != job.Progress {
 | 
			
		||||
		message := sd.Running
 | 
			
		||||
		if job.Progress == 100 {
 | 
			
		||||
			message = sd.Finished
 | 
			
		||||
		}
 | 
			
		||||
		s.notifyQueue.RPush(sd.NotifyMessage{UserId: job.UserId, JobId: int(job.Id), Message: message})
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func GetImageHash(action string) string {
 | 
			
		||||
	split := strings.Split(action, "::")
 | 
			
		||||
	if len(split) > 5 {
 | 
			
		||||
		return split[4]
 | 
			
		||||
	}
 | 
			
		||||
	return split[len(split)-1]
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,34 +0,0 @@
 | 
			
		||||
package mj
 | 
			
		||||
 | 
			
		||||
const (
 | 
			
		||||
	ApplicationID string = "936929561302675456"
 | 
			
		||||
	SessionID     string = "ea8816d857ba9ae2f74c59ae1a953afe"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type InteractionsRequest struct {
 | 
			
		||||
	Type          int            `json:"type"`
 | 
			
		||||
	ApplicationID string         `json:"application_id"`
 | 
			
		||||
	MessageFlags  *int           `json:"message_flags,omitempty"`
 | 
			
		||||
	MessageID     *string        `json:"message_id,omitempty"`
 | 
			
		||||
	GuildID       string         `json:"guild_id"`
 | 
			
		||||
	ChannelID     string         `json:"channel_id"`
 | 
			
		||||
	SessionID     string         `json:"session_id"`
 | 
			
		||||
	Data          map[string]any `json:"data"`
 | 
			
		||||
	Nonce         string         `json:"nonce,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type InteractionsResult struct {
 | 
			
		||||
	Code    int `json:"code"`
 | 
			
		||||
	Message string
 | 
			
		||||
	Error   map[string]any
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type CBReq struct {
 | 
			
		||||
	MessageId   string     `json:"message_id"`
 | 
			
		||||
	ReferenceId string     `json:"reference_id"`
 | 
			
		||||
	Image       Image      `json:"image"`
 | 
			
		||||
	Content     string     `json:"content"`
 | 
			
		||||
	Prompt      string     `json:"prompt"`
 | 
			
		||||
	Status      TaskStatus `json:"status"`
 | 
			
		||||
	Progress    int        `json:"progress"`
 | 
			
		||||
}
 | 
			
		||||
@@ -1,15 +1,25 @@
 | 
			
		||||
package oss
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/aliyun/aliyun-oss-go-sdk/oss"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/aliyun/aliyun-oss-go-sdk/oss"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type AliYunOss struct {
 | 
			
		||||
@@ -32,6 +42,10 @@ func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
 | 
			
		||||
		return nil, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if config.SubDir == "" {
 | 
			
		||||
		config.SubDir = "gpt"
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return &AliYunOss{
 | 
			
		||||
		config:   config,
 | 
			
		||||
		bucket:   bucket,
 | 
			
		||||
@@ -40,28 +54,34 @@ func NewAliYunOss(appConfig *types.AppConfig) (*AliYunOss, error) {
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (string, error) {
 | 
			
		||||
func (s AliYunOss) PutFile(ctx *gin.Context, name string) (File, error) {
 | 
			
		||||
	// 解析表单
 | 
			
		||||
	file, err := ctx.FormFile(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
		return File{}, err
 | 
			
		||||
	}
 | 
			
		||||
	// 打开上传文件
 | 
			
		||||
	src, err := file.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
		return File{}, err
 | 
			
		||||
	}
 | 
			
		||||
	defer src.Close()
 | 
			
		||||
 | 
			
		||||
	fileExt := filepath.Ext(file.Filename)
 | 
			
		||||
	objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	// 上传文件
 | 
			
		||||
	err = s.bucket.PutObject(objectKey, src)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
		return File{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Endpoint, objectKey), nil
 | 
			
		||||
	return File{
 | 
			
		||||
		Name:   file.Filename,
 | 
			
		||||
		ObjKey: objectKey,
 | 
			
		||||
		URL:    fmt.Sprintf("%s/%s", s.config.Domain, objectKey),
 | 
			
		||||
		Ext:    fileExt,
 | 
			
		||||
		Size:   file.Size,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
@@ -79,19 +99,39 @@ func (s AliYunOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse image URL: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	fileExt := filepath.Ext(parse.Path)
 | 
			
		||||
	objectKey := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	fileExt := utils.GetImgExt(parse.Path)
 | 
			
		||||
	objectKey := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	// 上传文件字节数据
 | 
			
		||||
	err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("https://%s.%s/%s", s.config.Bucket, s.config.Endpoint, objectKey), nil
 | 
			
		||||
	return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AliYunOss) PutBase64(base64Img string) (string, error) {
 | 
			
		||||
	imageData, err := base64.StdEncoding.DecodeString(base64Img)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error decoding base64:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
 | 
			
		||||
	// 上传文件字节数据
 | 
			
		||||
	err = s.bucket.PutObject(objectKey, bytes.NewReader(imageData))
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%s/%s", s.config.Domain, objectKey), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s AliYunOss) Delete(fileURL string) error {
 | 
			
		||||
	objectName := filepath.Base(fileURL)
 | 
			
		||||
	return s.bucket.DeleteObject(objectName)
 | 
			
		||||
	var objectKey string
 | 
			
		||||
	if strings.HasPrefix(fileURL, "http") {
 | 
			
		||||
		filename := filepath.Base(fileURL)
 | 
			
		||||
		objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
 | 
			
		||||
	} else {
 | 
			
		||||
		objectKey = fileURL
 | 
			
		||||
	}
 | 
			
		||||
	return s.bucket.DeleteObject(objectKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Uploader = AliYunOss{}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,17 @@
 | 
			
		||||
package oss
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"os"
 | 
			
		||||
@@ -23,23 +31,30 @@ func NewLocalStorage(config *types.AppConfig) LocalStorage {
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (string, error) {
 | 
			
		||||
func (s LocalStorage) PutFile(ctx *gin.Context, name string) (File, error) {
 | 
			
		||||
	file, err := ctx.FormFile(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with get form: %v", err)
 | 
			
		||||
		return File{}, fmt.Errorf("error with get form: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	filePath, err := utils.GenUploadPath(s.config.BasePath, file.Filename)
 | 
			
		||||
	path, err := utils.GenUploadPath(s.config.BasePath, file.Filename, false)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with generate filename: %s", err.Error())
 | 
			
		||||
		return File{}, fmt.Errorf("error with generate filename: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
	// 将文件保存到指定路径
 | 
			
		||||
	err = ctx.SaveUploadedFile(file, filePath)
 | 
			
		||||
	err = ctx.SaveUploadedFile(file, path)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with save upload file: %s", err.Error())
 | 
			
		||||
		return File{}, fmt.Errorf("error with save upload file: %s", err.Error())
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
 | 
			
		||||
	ext := filepath.Ext(file.Filename)
 | 
			
		||||
	return File{
 | 
			
		||||
		Name:   file.Filename,
 | 
			
		||||
		ObjKey: path,
 | 
			
		||||
		URL:    utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, path),
 | 
			
		||||
		Ext:    ext,
 | 
			
		||||
		Size:   file.Size,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
@@ -48,7 +63,7 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
		return "", fmt.Errorf("error with parse image URL: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	filename := filepath.Base(parse.Path)
 | 
			
		||||
	filePath, err := utils.GenUploadPath(s.config.BasePath, filename)
 | 
			
		||||
	filePath, err := utils.GenUploadPath(s.config.BasePath, filename, true)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with generate image dir: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
@@ -65,7 +80,24 @@ func (s LocalStorage) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s LocalStorage) PutBase64(base64Img string) (string, error) {
 | 
			
		||||
	imageData, err := base64.StdEncoding.DecodeString(base64Img)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error decoding base64:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	filePath, err := utils.GenUploadPath(s.config.BasePath, "", true)
 | 
			
		||||
	err = os.WriteFile(filePath, imageData, 0644)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error writing to file:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return utils.GenUploadUrl(s.config.BasePath, s.config.BaseURL, filePath), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s LocalStorage) Delete(fileURL string) error {
 | 
			
		||||
	if _, err := os.Stat(fileURL); err == nil {
 | 
			
		||||
		return os.Remove(fileURL)
 | 
			
		||||
	}
 | 
			
		||||
	filePath := strings.Replace(fileURL, s.config.BaseURL, s.config.BasePath, 1)
 | 
			
		||||
	return os.Remove(filePath)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,17 +1,26 @@
 | 
			
		||||
package oss
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/minio/minio-go/v7"
 | 
			
		||||
	"github.com/minio/minio-go/v7/pkg/credentials"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/minio/minio-go/v7"
 | 
			
		||||
	"github.com/minio/minio-go/v7/pkg/credentials"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type MiniOss struct {
 | 
			
		||||
@@ -29,6 +38,9 @@ func NewMiniOss(appConfig *types.AppConfig) (MiniOss, error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return MiniOss{}, err
 | 
			
		||||
	}
 | 
			
		||||
	if config.SubDir == "" {
 | 
			
		||||
		config.SubDir = "gpt"
 | 
			
		||||
	}
 | 
			
		||||
	return MiniOss{config: config, client: minioClient, proxyURL: appConfig.ProxyURL}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
@@ -48,7 +60,7 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
		return "", fmt.Errorf("error with parse image URL: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	fileExt := filepath.Ext(parse.Path)
 | 
			
		||||
	filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	info, err := s.client.PutObject(
 | 
			
		||||
		context.Background(),
 | 
			
		||||
		s.config.Bucket,
 | 
			
		||||
@@ -62,33 +74,64 @@ func (s MiniOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s MiniOss) PutFile(ctx *gin.Context, name string) (string, error) {
 | 
			
		||||
func (s MiniOss) PutFile(ctx *gin.Context, name string) (File, error) {
 | 
			
		||||
	file, err := ctx.FormFile(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with get form: %v", err)
 | 
			
		||||
		return File{}, fmt.Errorf("error with get form: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	// Open the uploaded file
 | 
			
		||||
	fileReader, err := file.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error opening file: %v", err)
 | 
			
		||||
		return File{}, fmt.Errorf("error opening file: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer fileReader.Close()
 | 
			
		||||
 | 
			
		||||
	fileExt := filepath.Ext(file.Filename)
 | 
			
		||||
	filename := fmt.Sprintf("%d%s", time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	fileExt := utils.GetImgExt(file.Filename)
 | 
			
		||||
	filename := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	info, err := s.client.PutObject(ctx, s.config.Bucket, filename, fileReader, file.Size, minio.PutObjectOptions{
 | 
			
		||||
		ContentType: file.Header.Get("Content-Type"),
 | 
			
		||||
	})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error uploading to MinIO: %v", err)
 | 
			
		||||
		return File{}, fmt.Errorf("error uploading to MinIO: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return File{
 | 
			
		||||
		Name:   file.Filename,
 | 
			
		||||
		ObjKey: info.Key,
 | 
			
		||||
		URL:    fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key),
 | 
			
		||||
		Ext:    fileExt,
 | 
			
		||||
		Size:   file.Size,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s MiniOss) PutBase64(base64Img string) (string, error) {
 | 
			
		||||
	imageData, err := base64.StdEncoding.DecodeString(base64Img)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error decoding base64:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
 | 
			
		||||
	info, err := s.client.PutObject(
 | 
			
		||||
		context.Background(),
 | 
			
		||||
		s.config.Bucket,
 | 
			
		||||
		objectKey,
 | 
			
		||||
		strings.NewReader(string(imageData)),
 | 
			
		||||
		int64(len(imageData)),
 | 
			
		||||
		minio.PutObjectOptions{ContentType: "image/png"})
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%s/%s/%s", s.config.Domain, s.config.Bucket, info.Key), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s MiniOss) Delete(fileURL string) error {
 | 
			
		||||
	objectName := filepath.Base(fileURL)
 | 
			
		||||
	return s.client.RemoveObject(context.Background(), s.config.Bucket, objectName, minio.RemoveObjectOptions{})
 | 
			
		||||
	var objectKey string
 | 
			
		||||
	if strings.HasPrefix(fileURL, "http") {
 | 
			
		||||
		filename := filepath.Base(fileURL)
 | 
			
		||||
		objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
 | 
			
		||||
	} else {
 | 
			
		||||
		objectKey = fileURL
 | 
			
		||||
	}
 | 
			
		||||
	return s.client.RemoveObject(context.Background(), s.config.Bucket, objectKey, minio.RemoveObjectOptions{})
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Uploader = MiniOss{}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,17 +1,27 @@
 | 
			
		||||
package oss
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"bytes"
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"chatplus/utils"
 | 
			
		||||
	"context"
 | 
			
		||||
	"encoding/base64"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
 | 
			
		||||
	"github.com/gin-gonic/gin"
 | 
			
		||||
	"github.com/qiniu/go-sdk/v7/auth/qbox"
 | 
			
		||||
	"github.com/qiniu/go-sdk/v7/storage"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"path/filepath"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type QinNiuOss struct {
 | 
			
		||||
@@ -21,7 +31,6 @@ type QinNiuOss struct {
 | 
			
		||||
	uploader  *storage.FormUploader
 | 
			
		||||
	manager   *storage.BucketManager
 | 
			
		||||
	proxyURL  string
 | 
			
		||||
	dir       string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
 | 
			
		||||
@@ -38,6 +47,9 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
 | 
			
		||||
	putPolicy := storage.PutPolicy{
 | 
			
		||||
		Scope: config.Bucket,
 | 
			
		||||
	}
 | 
			
		||||
	if config.SubDir == "" {
 | 
			
		||||
		config.SubDir = "gpt"
 | 
			
		||||
	}
 | 
			
		||||
	return QinNiuOss{
 | 
			
		||||
		config:    config,
 | 
			
		||||
		mac:       mac,
 | 
			
		||||
@@ -45,34 +57,40 @@ func NewQiNiuOss(appConfig *types.AppConfig) QinNiuOss {
 | 
			
		||||
		uploader:  formUploader,
 | 
			
		||||
		manager:   storage.NewBucketManager(mac, &storeConfig),
 | 
			
		||||
		proxyURL:  appConfig.ProxyURL,
 | 
			
		||||
		dir:       "chatgpt-plus",
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (string, error) {
 | 
			
		||||
func (s QinNiuOss) PutFile(ctx *gin.Context, name string) (File, error) {
 | 
			
		||||
	// 解析表单
 | 
			
		||||
	file, err := ctx.FormFile(name)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
		return File{}, err
 | 
			
		||||
	}
 | 
			
		||||
	// 打开上传文件
 | 
			
		||||
	src, err := file.Open()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
		return File{}, err
 | 
			
		||||
	}
 | 
			
		||||
	defer src.Close()
 | 
			
		||||
 | 
			
		||||
	fileExt := filepath.Ext(file.Filename)
 | 
			
		||||
	key := fmt.Sprintf("%s/%d%s", s.dir, time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	// 上传文件
 | 
			
		||||
	ret := storage.PutRet{}
 | 
			
		||||
	extra := storage.PutExtra{}
 | 
			
		||||
	err = s.uploader.Put(ctx, &ret, s.putPolicy.UploadToken(s.mac), key, src, file.Size, &extra)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
		return File{}, err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
 | 
			
		||||
	return File{
 | 
			
		||||
		Name:   file.Filename,
 | 
			
		||||
		ObjKey: key,
 | 
			
		||||
		URL:    fmt.Sprintf("%s/%s", s.config.Domain, ret.Key),
 | 
			
		||||
		Ext:    fileExt,
 | 
			
		||||
		Size:   file.Size,
 | 
			
		||||
	}, nil
 | 
			
		||||
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
@@ -90,8 +108,8 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error with parse image URL: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	fileExt := filepath.Ext(parse.Path)
 | 
			
		||||
	key := fmt.Sprintf("%s/%d%s", s.dir, time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	fileExt := utils.GetImgExt(parse.Path)
 | 
			
		||||
	key := fmt.Sprintf("%s/%d%s", s.config.SubDir, time.Now().UnixMicro(), fileExt)
 | 
			
		||||
	ret := storage.PutRet{}
 | 
			
		||||
	extra := storage.PutExtra{}
 | 
			
		||||
	// 上传文件字节数据
 | 
			
		||||
@@ -102,10 +120,32 @@ func (s QinNiuOss) PutImg(imageURL string, useProxy bool) (string, error) {
 | 
			
		||||
	return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s QinNiuOss) PutBase64(base64Img string) (string, error) {
 | 
			
		||||
	imageData, err := base64.StdEncoding.DecodeString(base64Img)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", fmt.Errorf("error decoding base64:%v", err)
 | 
			
		||||
	}
 | 
			
		||||
	objectKey := fmt.Sprintf("%s/%d.png", s.config.SubDir, time.Now().UnixMicro())
 | 
			
		||||
	ret := storage.PutRet{}
 | 
			
		||||
	extra := storage.PutExtra{}
 | 
			
		||||
	// 上传文件字节数据
 | 
			
		||||
	err = s.uploader.Put(context.Background(), &ret, s.putPolicy.UploadToken(s.mac), objectKey, bytes.NewReader(imageData), int64(len(imageData)), &extra)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return "", err
 | 
			
		||||
	}
 | 
			
		||||
	return fmt.Sprintf("%s/%s", s.config.Domain, ret.Key), nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (s QinNiuOss) Delete(fileURL string) error {
 | 
			
		||||
	objectName := filepath.Base(fileURL)
 | 
			
		||||
	key := fmt.Sprintf("%s/%s", s.dir, objectName)
 | 
			
		||||
	return s.manager.Delete(s.config.Bucket, key)
 | 
			
		||||
	var objectKey string
 | 
			
		||||
	if strings.HasPrefix(fileURL, "http") {
 | 
			
		||||
		filename := filepath.Base(fileURL)
 | 
			
		||||
		objectKey = fmt.Sprintf("%s/%s", s.config.SubDir, filename)
 | 
			
		||||
	} else {
 | 
			
		||||
		objectKey = fileURL
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return s.manager.Delete(s.config.Bucket, objectKey)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
var _ Uploader = QinNiuOss{}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,29 @@
 | 
			
		||||
package oss
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import "github.com/gin-gonic/gin"
 | 
			
		||||
 | 
			
		||||
const Local = "LOCAL"
 | 
			
		||||
const Minio = "MINIO"
 | 
			
		||||
const QiNiu = "QINIU"
 | 
			
		||||
const AliYun = "ALIYUN"
 | 
			
		||||
 | 
			
		||||
type File struct {
 | 
			
		||||
	Name   string `json:"name"`
 | 
			
		||||
	ObjKey string `json:"obj_key"`
 | 
			
		||||
	Size   int64  `json:"size"`
 | 
			
		||||
	URL    string `json:"url"`
 | 
			
		||||
	Ext    string `json:"ext"`
 | 
			
		||||
}
 | 
			
		||||
type Uploader interface {
 | 
			
		||||
	PutFile(ctx *gin.Context, name string) (string, error)
 | 
			
		||||
	PutFile(ctx *gin.Context, name string) (File, error)
 | 
			
		||||
	PutImg(imageURL string, useProxy bool) (string, error)
 | 
			
		||||
	PutBase64(imageData string) (string, error)
 | 
			
		||||
	Delete(fileURL string) error
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,7 +1,14 @@
 | 
			
		||||
package oss
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
@@ -9,11 +16,6 @@ type UploaderManager struct {
 | 
			
		||||
	handler Uploader
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
const Local = "LOCAL"
 | 
			
		||||
const Minio = "MINIO"
 | 
			
		||||
const QiNiu = "QINIU"
 | 
			
		||||
const AliYun = "ALIYUN"
 | 
			
		||||
 | 
			
		||||
func NewUploaderManager(config *types.AppConfig) (*UploaderManager, error) {
 | 
			
		||||
	active := Local
 | 
			
		||||
	if config.OSS.Active != "" {
 | 
			
		||||
 
 | 
			
		||||
@@ -1,9 +1,16 @@
 | 
			
		||||
package payment
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"chatplus/core/types"
 | 
			
		||||
	logger2 "chatplus/logger"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	logger2 "geekai/logger"
 | 
			
		||||
	"github.com/smartwalle/alipay/v3"
 | 
			
		||||
	"log"
 | 
			
		||||
	"net/url"
 | 
			
		||||
 
 | 
			
		||||
							
								
								
									
										169
									
								
								api/service/payment/hupipay_serive.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										169
									
								
								api/service/payment/hupipay_serive.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,169 @@
 | 
			
		||||
package payment
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/md5"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strconv"
 | 
			
		||||
	"strings"
 | 
			
		||||
	"time"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type HuPiPayService struct {
 | 
			
		||||
	appId     string
 | 
			
		||||
	appSecret string
 | 
			
		||||
	apiURL    string
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewHuPiPay(config *types.AppConfig) *HuPiPayService {
 | 
			
		||||
	return &HuPiPayService{
 | 
			
		||||
		appId:     config.HuPiPayConfig.AppId,
 | 
			
		||||
		appSecret: config.HuPiPayConfig.AppSecret,
 | 
			
		||||
		apiURL:    config.HuPiPayConfig.ApiURL,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type HuPiPayReq struct {
 | 
			
		||||
	AppId        string `json:"appid"`
 | 
			
		||||
	Version      string `json:"version"`
 | 
			
		||||
	TradeOrderId string `json:"trade_order_id"`
 | 
			
		||||
	TotalFee     string `json:"total_fee"`
 | 
			
		||||
	Title        string `json:"title"`
 | 
			
		||||
	NotifyURL    string `json:"notify_url"`
 | 
			
		||||
	ReturnURL    string `json:"return_url"`
 | 
			
		||||
	WapName      string `json:"wap_name"`
 | 
			
		||||
	CallbackURL  string `json:"callback_url"`
 | 
			
		||||
	Time         string `json:"time"`
 | 
			
		||||
	NonceStr     string `json:"nonce_str"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type HuPiResp struct {
 | 
			
		||||
	Openid    interface{} `json:"openid"`
 | 
			
		||||
	UrlQrcode string      `json:"url_qrcode"`
 | 
			
		||||
	URL       string      `json:"url"`
 | 
			
		||||
	ErrCode   int         `json:"errcode"`
 | 
			
		||||
	ErrMsg    string      `json:"errmsg,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Pay 执行支付请求操作
 | 
			
		||||
func (s *HuPiPayService) Pay(params HuPiPayReq) (HuPiResp, error) {
 | 
			
		||||
	data := url.Values{}
 | 
			
		||||
	simple := strconv.FormatInt(time.Now().Unix(), 10)
 | 
			
		||||
	params.AppId = s.appId
 | 
			
		||||
	params.Time = simple
 | 
			
		||||
	params.NonceStr = simple
 | 
			
		||||
	encode := utils.JsonEncode(params)
 | 
			
		||||
	m := make(map[string]string)
 | 
			
		||||
	_ = utils.JsonDecode(encode, &m)
 | 
			
		||||
	for k, v := range m {
 | 
			
		||||
		data.Add(k, fmt.Sprintf("%v", v))
 | 
			
		||||
	}
 | 
			
		||||
	// 生成签名
 | 
			
		||||
	data.Add("hash", s.Sign(data))
 | 
			
		||||
	// 发送支付请求
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/payment/do.html", s.apiURL)
 | 
			
		||||
	resp, err := http.PostForm(apiURL, data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return HuPiResp{}, fmt.Errorf("error with requst api: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	all, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return HuPiResp{}, fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var res HuPiResp
 | 
			
		||||
	err = utils.JsonDecode(string(all), &res)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return HuPiResp{}, fmt.Errorf("error with decode payment result: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if res.ErrCode != 0 {
 | 
			
		||||
		return HuPiResp{}, fmt.Errorf("error with generate pay url: %s", res.ErrMsg)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	return res, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Sign 签名方法
 | 
			
		||||
func (s *HuPiPayService) Sign(params url.Values) string {
 | 
			
		||||
	params.Del(`Sign`)
 | 
			
		||||
	var keys = make([]string, 0, 0)
 | 
			
		||||
	for key := range params {
 | 
			
		||||
		if params.Get(key) != `` {
 | 
			
		||||
			keys = append(keys, key)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	sort.Strings(keys)
 | 
			
		||||
 | 
			
		||||
	var pList = make([]string, 0, 0)
 | 
			
		||||
	for _, key := range keys {
 | 
			
		||||
		var value = strings.TrimSpace(params.Get(key))
 | 
			
		||||
		if len(value) > 0 {
 | 
			
		||||
			pList = append(pList, key+"="+value)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var src = strings.Join(pList, "&")
 | 
			
		||||
	src += s.appSecret
 | 
			
		||||
 | 
			
		||||
	md5bs := md5.Sum([]byte(src))
 | 
			
		||||
	return hex.EncodeToString(md5bs[:])
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Check 校验订单状态
 | 
			
		||||
func (s *HuPiPayService) Check(tradeNo string) error {
 | 
			
		||||
	data := url.Values{}
 | 
			
		||||
	data.Add("appid", s.appId)
 | 
			
		||||
	data.Add("open_order_id", tradeNo)
 | 
			
		||||
	stamp := strconv.FormatInt(time.Now().Unix(), 10)
 | 
			
		||||
	data.Add("time", stamp)
 | 
			
		||||
	data.Add("nonce_str", stamp)
 | 
			
		||||
	data.Add("hash", s.Sign(data))
 | 
			
		||||
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/payment/query.html", s.apiURL)
 | 
			
		||||
	resp, err := http.PostForm(apiURL, data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with http reqeust: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	body, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var r struct {
 | 
			
		||||
		ErrCode int `json:"errcode"`
 | 
			
		||||
		Data    struct {
 | 
			
		||||
			Status      string `json:"status"`
 | 
			
		||||
			OpenOrderId string `json:"open_order_id"`
 | 
			
		||||
		} `json:"data,omitempty"`
 | 
			
		||||
		ErrMsg string `json:"errmsg"`
 | 
			
		||||
		Hash   string `json:"hash"`
 | 
			
		||||
	}
 | 
			
		||||
	err = utils.JsonDecode(string(body), &r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.ErrCode == 0 && r.Data.Status == "OD" {
 | 
			
		||||
		return nil
 | 
			
		||||
	} else {
 | 
			
		||||
		logger.Debugf("%+v", r)
 | 
			
		||||
		return errors.New("order not paid:" + r.ErrMsg)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
							
								
								
									
										155
									
								
								api/service/payment/payjs_service.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										155
									
								
								api/service/payment/payjs_service.go
									
									
									
									
									
										Normal file
									
								
							@@ -0,0 +1,155 @@
 | 
			
		||||
package payment
 | 
			
		||||
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
// * Copyright 2023 The Geek-AI Authors. All rights reserved.
 | 
			
		||||
// * Use of this source code is governed by a Apache-2.0 license
 | 
			
		||||
// * that can be found in the LICENSE file.
 | 
			
		||||
// * @Author yangjian102621@163.com
 | 
			
		||||
// * +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"crypto/md5"
 | 
			
		||||
	"encoding/hex"
 | 
			
		||||
	"errors"
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"geekai/core/types"
 | 
			
		||||
	"geekai/utils"
 | 
			
		||||
	"io"
 | 
			
		||||
	"net/http"
 | 
			
		||||
	"net/url"
 | 
			
		||||
	"sort"
 | 
			
		||||
	"strings"
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
type PayJS struct {
 | 
			
		||||
	config *types.JPayConfig
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func NewPayJS(appConfig *types.AppConfig) *PayJS {
 | 
			
		||||
	return &PayJS{
 | 
			
		||||
		config: &appConfig.JPayConfig,
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type JPayReq struct {
 | 
			
		||||
	TotalFee   int    `json:"total_fee"`
 | 
			
		||||
	OutTradeNo string `json:"out_trade_no"`
 | 
			
		||||
	Subject    string `json:"body"`
 | 
			
		||||
	NotifyURL  string `json:"notify_url"`
 | 
			
		||||
	ReturnURL  string `json:"callback_url"`
 | 
			
		||||
}
 | 
			
		||||
type JPayReps struct {
 | 
			
		||||
	OutTradeNo string `json:"out_trade_no"`
 | 
			
		||||
	OrderId    string `json:"payjs_order_id"`
 | 
			
		||||
	ReturnCode int    `json:"return_code"`
 | 
			
		||||
	ReturnMsg  string `json:"return_msg"`
 | 
			
		||||
	Sign       string `json:"Sign"`
 | 
			
		||||
	TotalFee   string `json:"total_fee"`
 | 
			
		||||
	CodeUrl    string `json:"code_url,omitempty"`
 | 
			
		||||
	Qrcode     string `json:"qrcode,omitempty"`
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r JPayReps) IsOK() bool {
 | 
			
		||||
	return r.ReturnMsg == "SUCCESS"
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (js *PayJS) Pay(param JPayReq) JPayReps {
 | 
			
		||||
	param.NotifyURL = js.config.NotifyURL
 | 
			
		||||
	var p = url.Values{}
 | 
			
		||||
	encode := utils.JsonEncode(param)
 | 
			
		||||
	m := make(map[string]interface{})
 | 
			
		||||
	_ = utils.JsonDecode(encode, &m)
 | 
			
		||||
	for k, v := range m {
 | 
			
		||||
		p.Add(k, fmt.Sprintf("%v", v))
 | 
			
		||||
	}
 | 
			
		||||
	p.Add("mchid", js.config.AppId)
 | 
			
		||||
 | 
			
		||||
	p.Add("sign", js.sign(p))
 | 
			
		||||
 | 
			
		||||
	cli := http.Client{}
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/api/native", js.config.ApiURL)
 | 
			
		||||
	r, err := cli.PostForm(apiURL, p)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return JPayReps{ReturnMsg: err.Error()}
 | 
			
		||||
	}
 | 
			
		||||
	defer r.Body.Close()
 | 
			
		||||
	bs, err := io.ReadAll(r.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return JPayReps{ReturnMsg: err.Error()}
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var data JPayReps
 | 
			
		||||
	err = utils.JsonDecode(string(bs), &data)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return JPayReps{ReturnMsg: err.Error()}
 | 
			
		||||
	}
 | 
			
		||||
	return data
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (js *PayJS) PayH5(p url.Values) string {
 | 
			
		||||
	p.Add("mchid", js.config.AppId)
 | 
			
		||||
	p.Add("sign", js.sign(p))
 | 
			
		||||
	return fmt.Sprintf("%s/api/cashier?%s", js.config.ApiURL, p.Encode())
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (js *PayJS) sign(params url.Values) string {
 | 
			
		||||
	params.Del(`sign`)
 | 
			
		||||
	var keys = make([]string, 0, 0)
 | 
			
		||||
	for key := range params {
 | 
			
		||||
		if params.Get(key) != `` {
 | 
			
		||||
			keys = append(keys, key)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	sort.Strings(keys)
 | 
			
		||||
 | 
			
		||||
	var pList = make([]string, 0, 0)
 | 
			
		||||
	for _, key := range keys {
 | 
			
		||||
		var value = strings.TrimSpace(params.Get(key))
 | 
			
		||||
		if len(value) > 0 {
 | 
			
		||||
			pList = append(pList, key+"="+value)
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
	var src = strings.Join(pList, "&")
 | 
			
		||||
	src += "&key=" + js.config.PrivateKey
 | 
			
		||||
 | 
			
		||||
	md5bs := md5.Sum([]byte(src))
 | 
			
		||||
	md5res := hex.EncodeToString(md5bs[:])
 | 
			
		||||
	return strings.ToUpper(md5res)
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
// Check 查询订单支付状态
 | 
			
		||||
// @param tradeNo 支付平台交易 ID
 | 
			
		||||
func (js *PayJS) Check(tradeNo string) error {
 | 
			
		||||
	apiURL := fmt.Sprintf("%s/api/check", js.config.ApiURL)
 | 
			
		||||
	params := url.Values{}
 | 
			
		||||
	params.Add("payjs_order_id", tradeNo)
 | 
			
		||||
	params.Add("sign", js.sign(params))
 | 
			
		||||
	data := strings.NewReader(params.Encode())
 | 
			
		||||
	resp, err := http.Post(apiURL, "application/x-www-form-urlencoded", data)
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with http reqeust: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	defer resp.Body.Close()
 | 
			
		||||
	body, err := io.ReadAll(resp.Body)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with reading response: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	var r struct {
 | 
			
		||||
		ReturnCode int `json:"return_code"`
 | 
			
		||||
		Status     int `json:"status"`
 | 
			
		||||
	}
 | 
			
		||||
	err = utils.JsonDecode(string(body), &r)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return fmt.Errorf("error with decode response: %v", err)
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	if r.ReturnCode == 1 && r.Status == 1 {
 | 
			
		||||
		return nil
 | 
			
		||||
	} else {
 | 
			
		||||
		logger.Errorf("PayJs 支付验证响应:%s", string(body))
 | 
			
		||||
		return errors.New("order not paid")
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
Some files were not shown because too many files have changed in this diff Show More
		Reference in New Issue
	
	Block a user