@@ -21,16 +21,37 @@ import EllipsisIcon from "./EllipsisIcon"
2121import { SimpleSelectField } from "../SimpleSelect/SimpleSelect"
2222import { useFetch } from "./utils"
2323import { SelectChangeEvent } from "@mui/material/Select"
24+ import type { MathJax3Config } from "better-react-mathjax"
2425import { MathJaxContext } from "better-react-mathjax"
26+ import deepmerge from "@mui/utils/deepmerge"
2527
2628const ConditionalMathJaxWrapper : React . FC < {
2729 useMathJax : boolean
30+ config ?: MathJax3Config
2831 children : React . ReactNode
29- } > = ( { useMathJax, children } ) => {
32+ } > = ( { useMathJax, config = { } , children } ) => {
3033 if ( ! useMathJax ) {
3134 return < > { children } </ >
3235 }
33- return < MathJaxContext > { children } </ MathJaxContext >
36+
37+ return (
38+ < MathJaxContext
39+ config = { deepmerge (
40+ {
41+ startup : {
42+ typeset : false ,
43+ } ,
44+ loader : { load : [ "[tex]/boldsymbol" ] } ,
45+ tex : {
46+ packages : { "[+]" : [ "boldsymbol" ] } ,
47+ } ,
48+ } ,
49+ config ,
50+ ) }
51+ >
52+ { children }
53+ </ MathJaxContext >
54+ )
3455}
3556
3657const classes = {
@@ -251,13 +272,13 @@ const AiChatDisplay: FC<AiChatDisplayProps> = ({
251272 scrollElement,
252273 ref,
253274 useMathJax = false ,
275+ mathJaxConfig,
254276 onSubmit,
255277 problemSetListUrl,
256278 problemSetInitialMessages,
257279 problemSetEmptyMessages,
258280 ...others // Could contain data attributes
259281} ) => {
260- const containerRef = useRef < HTMLDivElement > ( null )
261282 const messagesContainerRef = useRef < HTMLDivElement > ( null )
262283 const chatScreenRef = useRef < HTMLDivElement > ( null )
263284 const promptInputRef = useRef < HTMLDivElement > ( null )
@@ -291,7 +312,9 @@ const AiChatDisplay: FC<AiChatDisplayProps> = ({
291312 const [ showEntryScreen , setShowEntryScreen ] = useState ( entryScreenEnabled )
292313 useEffect ( ( ) => {
293314 if ( ! showEntryScreen ) {
294- promptInputRef . current ?. querySelector ( "input" ) ?. focus ( )
315+ promptInputRef . current
316+ ?. querySelector ( "input" )
317+ ?. focus ( { preventScroll : true } )
295318 }
296319 } , [ showEntryScreen ] )
297320
@@ -318,7 +341,7 @@ const AiChatDisplay: FC<AiChatDisplayProps> = ({
318341 ] )
319342 }
320343 }
321- } , [ problemSetListResponse ] )
344+ } , [ problemSetListResponse , problemSetEmptyMessages , setMessages ] )
322345
323346 useEffect ( ( ) => {
324347 if (
@@ -366,7 +389,7 @@ const AiChatDisplay: FC<AiChatDisplayProps> = ({
366389 const externalScroll = ! ! scrollElement
367390
368391 return (
369- < Container className = { className } ref = { containerRef } >
392+ < Container className = { className } >
370393 { showEntryScreen ? (
371394 < EntryScreen
372395 className = { classes . entryScreenContainer }
@@ -420,41 +443,37 @@ const AiChatDisplay: FC<AiChatDisplayProps> = ({
420443 ) : null
421444 }
422445 />
423- < ConditionalMathJaxWrapper useMathJax = { useMathJax } >
446+ < ConditionalMathJaxWrapper
447+ useMathJax = { useMathJax }
448+ config = { mathJaxConfig }
449+ >
424450 < MessagesContainer
425451 className = { classes . messagesContainer }
426452 externalScroll = { externalScroll }
427453 ref = { messagesContainerRef }
428454 >
429- { messages . map ( ( m : Message , i ) => {
430- // Our Markdown+Mathjax has issues when rendering streaming display math
431- // Force a re-render of the last (streaming) message when it's done loading.
432- const key =
433- i === messages . length - 1 && isLoading
434- ? `isLoading-${ m . id } `
435- : m . id
455+ { messages . map ( ( message : Message , index : number ) => {
436456 return (
437457 < MessageRow
438- key = { key }
439- data-chat-role = { m . role }
458+ key = { index }
459+ data-chat-role = { message . role }
440460 className = { classNames ( classes . messageRow , {
441- [ classes . messageRowUser ] : m . role === "user" ,
442- [ classes . messageRowAssistant ] : m . role === "assistant" ,
461+ [ classes . messageRowUser ] : message . role === "user" ,
462+ [ classes . messageRowAssistant ] :
463+ message . role === "assistant" ,
443464 } ) }
444465 >
445466 < Message className = { classes . message } >
446- < VisuallyHidden as = { m . role === "user" ? "h5" : "h6" } >
447- { m . role === "user"
467+ < VisuallyHidden
468+ as = { message . role === "user" ? "h5" : "h6" }
469+ >
470+ { message . role === "user"
448471 ? "You said: "
449472 : "Assistant said: " }
450473 </ VisuallyHidden >
451- { useMathJax ? (
452- < Markdown enableMathjax = { true } >
453- { replaceMathjax ( m . content ) }
454- </ Markdown >
455- ) : (
456- < Markdown > { m . content } </ Markdown >
457- ) }
474+ < Markdown useMathJax = { useMathJax } >
475+ { message . content }
476+ </ Markdown >
458477 </ Message >
459478 </ MessageRow >
460479 )
@@ -583,21 +602,4 @@ const AiChat: FC<AiChatProps> = ({
583602 )
584603}
585604
586- // react-markdown expects Mathjax delimiters to be $...$ or $$...$$
587- // the prompt for the tutorbot asks for Mathjax tags with $ format but
588- // the LLM does not get it right all the time
589- // this function replaces the Mathjax tags with the correct format
590- // eventually we will probably be able to remove this as LLMs get better
591- function replaceMathjax ( inputString : string ) : string {
592- // Replace instances of \(...\) and \[...\] Mathjax tags with $...$
593- // and $$...$$ tags.
594- const INLINE_MATH_REGEX = / \\ \( ( .* ?) \\ \) / g
595- const DISPLAY_MATH_REGEX = / \\ \[ ( ( [ \s \S ] * ?) ) \\ \] / g
596- inputString = inputString . replace (
597- INLINE_MATH_REGEX ,
598- ( _match , p1 ) => `$${ p1 } $` ,
599- )
600- return inputString . replace ( DISPLAY_MATH_REGEX , ( _match , p1 ) => `$$${ p1 } $$` )
601- }
602-
603- export { AiChatDisplay , AiChat , replaceMathjax }
605+ export { AiChatDisplay , AiChat }
0 commit comments