ai.controller.js 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980
  1. const aiService = require('../service/ai/ai.service.js');
  2. const util = require('../utils/util');
  3. const { model: openaiModel } = require('../service/ai/openai.service');
  4. const { findRelevantContent } = require('../service/ai/embedding.service');
  5. const { getSystemPrompt } = require('../service/ai/prompt.service');
  6. const { streamText } = require('ai');
  // Matches the `data:image/<subtype>;base64,` prefix of an image data URL.
  // The subtype class also accepts `.`, `+` and `-`, so subtypes such as
  // `svg+xml` or `x-icon` are recognised (a bare `\w+` stops at those characters).
  const IMAGE_DATA_URL_PREFIX = /^data:image\/[\w.+-]+;base64,/;

  /**
   * Returns a copy of `messages` in which every content part of type
   * `image_url` is rewritten to type `image`.
   *
   * For each message whose `content` is an array, every part is shallow-copied.
   * Parts with `type === 'image_url'` additionally get `type: 'image'` and an
   * `image` field holding `image_url.url` with any leading
   * `data:image/...;base64,` prefix removed; URLs without that prefix are
   * copied as-is. The original `image_url` property is kept on the part.
   * Messages whose `content` is not an array keep their `content` untouched.
   * The input array and its objects are not mutated.
   *
   * @param {Array<object>} messages - Chat messages; `content` is either an
   *   array of parts or any other value (presumably a string — confirm with callers).
   * @returns {Array<object>} New array of shallow-copied messages.
   * @throws {TypeError} If `messages` is not an array, or an `image_url` part
   *   has no `image_url.url` string.
   */
  function formatMessages(messages) {
    const convertPart = (part) => {
      if (part.type !== 'image_url') {
        return { ...part };
      }
      return {
        ...part,
        type: 'image',
        image: part.image_url.url.replace(IMAGE_DATA_URL_PREFIX, ''),
      };
    };

    return messages.map((message) => ({
      ...message,
      content: Array.isArray(message.content)
        ? message.content.map(convertPart)
        : message.content,
    }));
  }
  21. module.exports = {
  22. async codeGenerate(ctx) {
  23. const { message } = ctx.request.body;
  24. if (!message) {
  25. return util.fail(ctx, '请输入提示词');
  26. }
  27. try {
  28. const result = await aiService.ApiZhipuAi(message);
  29. util.success(ctx, {
  30. jsx: result[0],
  31. config: result[1],
  32. });
  33. } catch (error) {
  34. util.fail(ctx, error);
  35. }
  36. },
  37. async chatStream(ctx) {
  38. try {
  39. const { messages } = ctx.request.body;
  40. const lastMessage = messages[messages.length - 1];
  41. // 处理最后一条消息
  42. const lastMessageContent = Array.isArray(lastMessage.content)
  43. ? lastMessage.content
  44. .filter((c) => c.type === 'text')
  45. .map((c) => c.text)
  46. .join('')
  47. : lastMessage.content;
  48. // 获取相关上下文
  49. const relevantContent = await findRelevantContent(lastMessageContent);
  50. const systemPrompt = getSystemPrompt(relevantContent.map((c) => c.content).join('\n'));
  51. const { textStream } = streamText({
  52. model: openaiModel,
  53. system: systemPrompt,
  54. messages: formatMessages(messages),
  55. });
  56. ctx.set('Access-Control-Allow-Origin', '*');
  57. ctx.set('Access-Control-Allow-Methods', 'POST, GET, OPTIONS');
  58. ctx.set('Access-Control-Allow-Headers', 'Content-Type, Authorization, X-Requested-With');
  59. ctx.status = 200;
  60. for await (const textPart of textStream) {
  61. // 返回消息
  62. ctx.res.write(textPart);
  63. }
  64. ctx.res.end();
  65. util.success(ctx, '对话结束');
  66. } catch (error) {
  67. ctx.status = 500;
  68. console.error('Chat stream error:', error);
  69. ctx.body = { error: error.message || 'Internal Server Error' };
  70. }
  71. },
  72. };