@@ -128,15 +128,15 @@ def search(
128128
129129 """
130130 if return_context is None :
131- if isinstance (self .llm , LLMInterface ):
131+ if self .is_langchain_compatible ():
132+ return_context = True
133+ else : # e.g. LLMInterface
132134 warnings .warn (
133135 "The default value of 'return_context' will change from 'False'"
134136 " to 'True' in a future version." ,
135137 DeprecationWarning ,
136138 )
137139 return_context = False
138- else :
139- return_context = True
140140 try :
141141 validated_data = RagSearchModel (
142142 query_text = query_text ,
@@ -164,9 +164,7 @@ def search(
164164 logger .debug ("RAG: retriever_result=%s" , prettify (retriever_result ))
165165 logger .debug ("RAG: prompt=%s" , prompt )
166166
167- if isinstance (self .llm , LLMInterfaceV2 ) or self .llm .__module__ .startswith (
168- "langchain"
169- ):
167+ if self .is_langchain_compatible ():
170168 messages = legacy_inputs_to_messages (
171169 prompt = prompt ,
172170 message_history = message_history ,
@@ -206,9 +204,7 @@ def _build_query(
206204 summarization_prompt = self ._chat_summary_prompt (
207205 message_history = message_history
208206 )
209- if isinstance (self .llm , LLMInterfaceV2 ) or self .llm .__module__ .startswith (
210- "langchain"
211- ):
207+ if self .is_langchain_compatible ():
212208 messages = legacy_inputs_to_messages (
213209 summarization_prompt ,
214210 system_instruction = summary_system_message ,
@@ -227,6 +223,12 @@ def _build_query(
227223 return self .conversation_prompt (summary = summary , current_query = query_text )
228224 return query_text
229225
226+ def is_langchain_compatible (self ) -> bool :
227+ """Checks if the LLM is compatible with LangChain."""
228+ return isinstance (self .llm , LLMInterfaceV2 ) or self .llm .__module__ .startswith (
229+ "langchain"
230+ )
231+
230232 def _chat_summary_prompt (self , message_history : List [LLMMessage ]) -> str :
231233 message_list = [
232234 f"{ message ['role' ]} : { message ['content' ]} " for message in message_history
0 commit comments