// BaseModelProvider.js (2.8 KB)

  1. const { InMemoryChatMessageHistory } = require("@langchain/core/chat_history");
  2. const { BaseMessage, MessageContent } = require("@langchain/core/messages");
  3. const {
  4. ChatPromptTemplate,
  5. HumanMessagePromptTemplate,
  6. MessagesPlaceholder,
  7. } = require("@langchain/core/prompts");
  8. const {
  9. Runnable,
  10. RunnableWithMessageHistory,
  11. } = require("@langchain/core/runnables");
  12. const { z } = require("zod");
  /**
   * Base class for chat-model providers built on LangChain runnables.
   *
   * This class calls `this.createModel()` from `getModel()` but does not define
   * it, so a subclass has to supply it. Conversation histories live in a static
   * map keyed by session id and are therefore shared by every instance.
   */
  class BaseModelProvider {
    /**
     * Session id -> chat message history.
     *
     * Created without a prototype so that ids such as "constructor" or
     * "toString" cannot be mistaken for inherited Object members.
     */
    static sessionIdHistoriesMap = Object.create(null);

    /**
     * Flattens message content into plain text.
     *
     * @param {string | Array<{ type?: string, text?: string }> | null | undefined} content
     *   Either a plain string or a list of content parts.
     * @returns {string} The string itself, or the concatenated `text` of every
     *   part whose `type` is "text". Other parts contribute nothing, and
     *   content that is neither a string nor an array yields "".
     */
    static answerContentToText(content) {
      if (typeof content === "string") {
        return content;
      }
      if (!Array.isArray(content)) {
        // null / undefined / unexpected shapes: nothing textual to extract.
        return "";
      }
      return content
        .map((part) => (part?.type === "text" ? part.text ?? "" : ""))
        .join("");
    }

    /**
     * Returns the model, creating it through `createModel()` on first use and
     * caching it on `this.model` afterwards.
     *
     * @returns {Promise<*>} The cached or freshly created model.
     */
    async getModel() {
      if (!this.model) {
        this.model = await this.createModel();
      }
      return this.model;
    }

    /**
     * Builds the chat prompt: an optional "history" placeholder followed by a
     * human message rendered from the `{input}` variable.
     *
     * @param {{ useHistory?: boolean }} [options] `useHistory` defaults to true.
     * @returns {ChatPromptTemplate} The assembled prompt template.
     */
    createPrompt(options) {
      const { useHistory = true } = options ?? {};
      const messages = [];
      if (useHistory) {
        messages.push(new MessagesPlaceholder("history"));
      }
      messages.push(HumanMessagePromptTemplate.fromTemplate("{input}"));
      return ChatPromptTemplate.fromMessages(messages);
    }

    /**
     * Looks up the history stored for a session, creating it when missing.
     *
     * @param {string} sessionId Key into `sessionIdHistoriesMap`.
     * @param {Array} [appendHistoryMessages] Messages used to seed a newly
     *   created history. They are ignored when the session already has one.
     * @returns {Promise<InMemoryChatMessageHistory>} The session's history.
     */
    async getHistory(sessionId, appendHistoryMessages) {
      const histories = BaseModelProvider.sessionIdHistoriesMap;
      // Own-property check: the static map may have been swapped for a plain
      // object elsewhere, and inherited members must never count as a hit.
      const isStored = Object.prototype.hasOwnProperty.call(
        histories,
        sessionId
      );
      if (isStored && histories[sessionId] !== undefined) {
        return histories[sessionId];
      }

      const messageHistory = new InMemoryChatMessageHistory();
      if (appendHistoryMessages && appendHistoryMessages.length > 0) {
        await messageHistory.addMessages(appendHistoryMessages);
      }
      histories[sessionId] = messageHistory;
      return histories[sessionId];
    }

    /**
     * Wraps a chain in a `RunnableWithMessageHistory` that reads the user text
     * from the "input" key and injects past messages under the "history" key.
     *
     * @param {*} chain The runnable to wrap.
     * @param {Array} historyMessages Seed messages forwarded to `getHistory`.
     * @returns {RunnableWithMessageHistory} The history-aware runnable.
     */
    createRunnableWithMessageHistory(chain, historyMessages) {
      return new RunnableWithMessageHistory({
        runnable: chain,
        getMessageHistory: (sessionId) =>
          this.getHistory(sessionId, historyMessages),
        inputMessagesKey: "input",
        historyMessagesKey: "history",
      });
    }

    /**
     * Creates the prompt -> model runnable, optionally wrapped with history.
     *
     * @param {{ useHistory?: boolean, historyMessages?: Array, signal?: * }} [options]
     *   `useHistory` defaults to true and `historyMessages` to []. When
     *   `signal` is given it is bound to the model with `model.bind({ signal })`
     *   (presumably an AbortSignal - confirm against callers).
     * @returns {Promise<*>} The plain chain, or the history-aware wrapper.
     */
    async createRunnable(options) {
      const { useHistory = true, historyMessages = [], signal } = options ?? {};
      const model = await this.getModel();
      const prompt = await this.createPrompt({ useHistory });
      const chain = prompt.pipe(signal ? model.bind({ signal }) : model);

      if (!useHistory) {
        return chain;
      }
      // `|| []` also covers an explicit null, which the default does not.
      return await this.createRunnableWithMessageHistory(
        chain,
        historyMessages || []
      );
    }

    /**
     * Creates a runnable whose model output is constrained by `zodSchema`.
     *
     * The schema is applied through the model's `withStructuredOutput`. When no
     * schema is passed the plain model is used, which keeps the behaviour this
     * method had before the schema was wired in.
     *
     * @param {{ useHistory?: boolean, historyMessages?: Array, zodSchema?: * }} [options]
     *   `useHistory` defaults to true and `historyMessages` to [].
     * @returns {Promise<*>} The plain chain, or the history-aware wrapper.
     * @throws {TypeError} If a schema is given but the model has no
     *   `withStructuredOutput` method.
     */
    async createStructuredOutputRunnable(options) {
      const {
        useHistory = true,
        historyMessages = [],
        zodSchema,
      } = options ?? {};
      const model = await this.getModel();
      const prompt = await this.createPrompt({ useHistory });

      let outputModel = model;
      if (zodSchema != null) {
        if (typeof model.withStructuredOutput !== "function") {
          throw new TypeError(
            "The model does not support withStructuredOutput; cannot apply zodSchema."
          );
        }
        outputModel = model.withStructuredOutput(zodSchema);
      }
      const chain = prompt.pipe(outputModel);

      if (!useHistory) {
        return chain;
      }
      // NOTE(review): the history wrapper records the chain output; confirm it
      // accepts the structured (non-message) result produced by the schema.
      return await this.createRunnableWithMessageHistory(
        chain,
        historyMessages || []
      );
    }
  }
  // CommonJS export surface: the provider base class is the only export.
  module.exports = { BaseModelProvider };