提交 15a93038 authored 作者: 王鹏飞's avatar 王鹏飞

refactor: 重构useWenku.js

上级 4776215b
...@@ -11,7 +11,7 @@ import { ...@@ -11,7 +11,7 @@ import {
import { ConfigProvider, Modal, Input, Button } from 'antd' import { ConfigProvider, Modal, Input, Button } from 'antd'
const { TextArea } = Input const { TextArea } = Input
import './AISearchModal.less' import './AISearchModal.less'
import { usePaperOutline, useOutlineToPaper } from '@/hooks/useWenku' import { usePaper } from '@/hooks/useWenku'
import { useCopyToClipboard } from 'react-use' import { useCopyToClipboard } from 'react-use'
export default function AIModal() { export default function AIModal() {
...@@ -20,7 +20,7 @@ export default function AIModal() { ...@@ -20,7 +20,7 @@ export default function AIModal() {
const [textIndent, setTextIndent] = useState(0) const [textIndent, setTextIndent] = useState(0)
const prePromptRef = useRef(null) const prePromptRef = useRef(null)
const messageScrollRef = useRef(null) const messageScrollRef = useRef(null)
const { messages, setMessages, isLoading, query } = usePaperOutline() const { messages, setMessages, isLoading, generateOutline, generatePaper } = usePaper()
useEffect(() => { useEffect(() => {
if (prePromptRef.current) { if (prePromptRef.current) {
const width = prePromptRef.current.offsetWidth + 10 const width = prePromptRef.current.offsetWidth + 10
...@@ -43,7 +43,7 @@ export default function AIModal() { ...@@ -43,7 +43,7 @@ export default function AIModal() {
} }
} }
const handleSearch = () => { const handleSearch = () => {
query(prePrompt + content) generateOutline(prePrompt + content)
setContent('') setContent('')
} }
...@@ -55,13 +55,11 @@ export default function AIModal() { ...@@ -55,13 +55,11 @@ export default function AIModal() {
} }
// 生成论文 // 生成论文
const { query: outlineToPaper } = useOutlineToPaper()
const handleGeneratePaper = async (msg) => { const handleGeneratePaper = async (msg) => {
setMessages((prevMessages) => { setMessages((prevMessages) => {
return [...prevMessages, { content: '正在生成长文...', role: 'ai', tips: '预计10分钟', queryID: msg.queryID }] return [...prevMessages, { content: '正在生成长文...', role: 'ai', tips: '预计10分钟', queryID: msg.queryID }]
}) })
const paper = await outlineToPaper({ userQuery: msg.userQuery, queryID: msg.queryID, outline: msg.content }) const paper = await generatePaper({ userQuery: msg.userQuery, queryID: msg.queryID, outline: msg.content })
setMessages((prevMessages) => { setMessages((prevMessages) => {
prevMessages.pop() prevMessages.pop()
return [...prevMessages, { content: '已为您生成初稿,请点击查看', role: 'ai', queryID: msg.queryID, paper }] return [...prevMessages, { content: '已为您生成初稿,请点击查看', role: 'ai', queryID: msg.queryID, paper }]
...@@ -101,7 +99,7 @@ export default function AIModal() { ...@@ -101,7 +99,7 @@ export default function AIModal() {
size="small" size="small"
icon={<UndoOutlined />} icon={<UndoOutlined />}
disabled={isLoading} disabled={isLoading}
onClick={() => query(msg.userQuery)}> onClick={() => generateOutline(msg.userQuery)}>
换个大纲 换个大纲
</Button> </Button>
<Button <Button
......
import { useState, useCallback } from 'react'
/**
* 通用异步任务处理hook
* @param {Function} taskFn 异步任务函数
* @returns {Object} { isLoading, error, execute }
*/
export function useAsyncTask(taskFn) {
const [isLoading, setIsLoading] = useState(false)
const [error, setError] = useState(null)
const execute = useCallback(
async (...args) => {
try {
setIsLoading(true)
setError(null)
return await taskFn(...args)
} catch (err) {
setError(err)
throw err
} finally {
setIsLoading(false)
}
},
[taskFn]
)
return { isLoading, error, execute }
}
import { useState, useEffect, useRef } from 'react' import { useState, useEffect, useRef } from 'react'
import { aiSearch, paperOutline, outlineToPaper, download } from '@/api/wenku' import { aiSearch, paperOutline, outlineToPaper, download } from '@/api/wenku'
import { useAsyncTask } from './useAsyncTask'
/**
* @typedef {Object} Message
* @property {string} content
* @property {'user' | 'ai'} role
* @property {Array<{title: string, url: string}>} [searchReferList]
* @property {string} [userQuery]
* @property {Chapter[]} [chapters]
* @property {string} [logID]
* @property {string} [queryID]
*/
/**
* @typedef {Object} Chapter
* @property {string} title
* @property {number} level
* @property {string} desc
* @property {string} chapter
*/
/**
* @typedef {Object} Paper
* @property {string} content
* @property {string} [downloadLink]
* @property {string} [docID]
*/
/**
* @typedef {Object} DownloadResult
* @property {string} download_link
*/
// AI搜索 // AI搜索
export function useSearch() { export function useSearch() {
const [messages, setMessages] = useState([]) const [messages, setMessages] = useState([])
const [isLoading, setIsLoading] = useState(false)
const query = async (userQuery) => { const { isLoading, execute } = useAsyncTask(async (userQuery) => {
try { setMessages((prevMessages) => [...prevMessages, { content: userQuery, role: 'user' }])
setIsLoading(true) const currentMessage = { content: '', searchReferList: [], role: 'ai', userQuery }
setMessages((prevMessages) => [...prevMessages, { content: userQuery, role: 'user' }])
const currentMessage = { content: '', searchReferList: [], role: 'ai', userQuery } await aiSearch({
await aiSearch({ body: JSON.stringify({ userQuery }),
body: JSON.stringify({ userQuery }), onmessage(message) {
onmessage(message) { try {
try { const data = JSON.parse(message.data)
const data = JSON.parse(message.data) const content = data.raw?.content || ''
const content = data.raw?.content || '' const searchReferList = data.raw?.searchReferList || []
const searchReferList = data.raw?.searchReferList || [] currentMessage.content += content
currentMessage.content += content if (searchReferList.length) {
if (searchReferList.length) { currentMessage.searchReferList.push(...searchReferList)
currentMessage.searchReferList.push(...searchReferList) setMessages((prevMessages) => {
setMessages((prevMessages) => { const lastMessage = prevMessages[prevMessages.length - 1]
const lastMessage = prevMessages[prevMessages.length - 1] if (lastMessage?.role === 'ai') {
if (lastMessage?.role === 'ai') { return [...prevMessages.slice(0, -1), currentMessage]
return [...prevMessages.slice(0, -1), currentMessage] }
} return [...prevMessages, currentMessage]
return [...prevMessages, currentMessage] })
})
}
} catch (error) {
console.log(error)
} }
}, } catch (error) {
}) console.error('Error parsing message:', error)
} catch (error) { }
console.log(error) },
} finally { })
setIsLoading(false) })
}
} return { messages, setMessages, isLoading, query: execute }
return { messages, setMessages, isLoading, query }
} }
function parseInput(inputStr) { function parseInput(inputStr) {
...@@ -95,76 +122,21 @@ function parseInput(inputStr) { ...@@ -95,76 +122,21 @@ function parseInput(inputStr) {
return result return result
} }
// 生成论文大纲 /**
export function usePaperOutline() { * 论文生成全流程管理
* @returns {{
* messages: Message[],
* paper: Paper,
* isLoading: boolean,
* generateOutline: (userQuery: string) => Promise<void>,
* generatePaper: (data: any) => Promise<Paper>,
* downloadPaper: (docID: string) => Promise<string>
* }}
*/
export function usePaper() {
const [messages, setMessages] = useState([]) const [messages, setMessages] = useState([])
const [isLoading, setIsLoading] = useState(false) const [paper, setPaper] = useState(null)
const query = async (userQuery) => {
try {
setIsLoading(true)
setMessages((prevMessages) => [...prevMessages, { content: userQuery, role: 'user' }])
const currentMessage = { content: '', role: 'ai', userQuery, chapters: [], logID: '', queryID: '' }
await paperOutline({
body: JSON.stringify({ userQuery }),
onmessage(message) {
try {
const data = JSON.parse(message.data)
const content = data.raw?.data || ''
if (data.logID) currentMessage.logID = data.logID
if (data.queryID) currentMessage.queryID = data.queryID
currentMessage.content += content
currentMessage.chapters = parseInput(currentMessage.content)
if (content) {
setMessages((prevMessages) => {
const lastMessage = prevMessages[prevMessages.length - 1]
if (lastMessage?.role === 'ai') {
return [...prevMessages.slice(0, -1), currentMessage]
}
return [...prevMessages, currentMessage]
})
}
} catch (error) {
console.log(error)
}
},
})
} catch (error) {
console.log(error)
} finally {
setIsLoading(false)
}
}
return { messages, setMessages, isLoading, query }
}
// 大纲生成论文
export function useOutlineToPaper() {
const { queryWithPolling } = useDownload()
const [isLoading, setIsLoading] = useState(false)
const [paper, setPaper] = useState('')
const query = async (data) => {
try {
setIsLoading(true)
const { raw } = await outlineToPaper(data)
setPaper(raw)
const downloadLink = await queryWithPolling({
docID: '280fe43ee63a580216fc700abb68a98271feac83' || raw.docID,
})
setPaper({ ...raw, downloadLink })
return { ...raw, downloadLink }
} catch (error) {
console.log(error)
} finally {
setIsLoading(false)
}
}
return { paper, setPaper, isLoading, query }
}
// 文档下载
export function useDownload() {
const [downloadLink, setDownloadLink] = useState('') const [downloadLink, setDownloadLink] = useState('')
const [isLoading, setIsLoading] = useState(false)
const timerRef = useRef(null) const timerRef = useRef(null)
const clearTimer = () => { const clearTimer = () => {
...@@ -174,25 +146,58 @@ export function useDownload() { ...@@ -174,25 +146,58 @@ export function useDownload() {
} }
} }
const query = async (data) => { const { isLoading: isOutlineLoading, execute: generateOutline } = useAsyncTask(async (userQuery) => {
try { setMessages((prevMessages) => [...prevMessages, { content: userQuery, role: 'user' }])
setIsLoading(true) const currentMessage = { content: '', role: 'ai', userQuery, chapters: [], logID: '', queryID: '' }
const { raw } = await download(data)
if (raw) { await paperOutline({
setDownloadLink(raw.download_link) body: JSON.stringify({ userQuery }),
return raw.download_link onmessage(message) {
} try {
} catch (error) { const data = JSON.parse(message.data)
console.error(error) const content = data.raw?.data || ''
} finally { if (data.logID) currentMessage.logID = data.logID
setIsLoading(false) if (data.queryID) currentMessage.queryID = data.queryID
currentMessage.content += content
currentMessage.chapters = parseInput(currentMessage.content)
if (content) {
setMessages((prevMessages) => {
const lastMessage = prevMessages[prevMessages.length - 1]
if (lastMessage?.role === 'ai') {
return [...prevMessages.slice(0, -1), currentMessage]
}
return [...prevMessages, currentMessage]
})
}
} catch (error) {
console.error('Error parsing message:', error)
}
},
})
})
const { isLoading: isPaperLoading, execute: generatePaper } = useAsyncTask(async (data) => {
const { raw } = await outlineToPaper(data)
setPaper(raw)
const downloadLink = await downloadWithPolling({
docID: raw.docID,
})
setPaper({ ...raw, downloadLink })
return { ...raw, downloadLink }
})
const { isLoading: isDownloadLoading, execute: downloadPaper } = useAsyncTask(async (docID) => {
const { raw } = await download({ docID })
if (raw) {
setDownloadLink(raw.download_link)
return raw.download_link
} }
} })
const queryWithPolling = async (data, interval = 3000) => { const downloadWithPolling = async (data, interval = 3000) => {
const poll = async (resolve, reject) => { const poll = async (resolve, reject) => {
try { try {
const link = await query(data) const link = await downloadPaper(data.docID)
if (link) { if (link) {
resolve(link) resolve(link)
return return
...@@ -209,11 +214,17 @@ export function useDownload() { ...@@ -209,11 +214,17 @@ export function useDownload() {
} }
useEffect(() => { useEffect(() => {
// 清除定时器的逻辑 return () => clearTimer()
return () => {
clearTimer()
}
}, []) }, [])
return { downloadLink, isLoading, query, queryWithPolling, clearTimer } return {
messages,
setMessages,
paper,
downloadLink,
isLoading: isOutlineLoading || isPaperLoading || isDownloadLoading,
generateOutline,
generatePaper,
downloadPaper,
}
} }
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论