
Commit b951fe6

avoid repeated langchain compatible check code
1 parent abf0efa commit b951fe6
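
This commit consolidates the LangChain-compatibility check that was duplicated across `search` and `_build_query` into a single `is_langchain_compatible()` helper; a standalone sketch of the extracted check follows the diff below.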

File tree

1 file changed: +11 −9


src/neo4j_graphrag/generation/graphrag.py

Lines changed: 11 additions & 9 deletions
@@ -128,15 +128,15 @@ def search(
 
         """
         if return_context is None:
-            if isinstance(self.llm, LLMInterface):
+            if self.is_langchain_compatible():
+                return_context = True
+            else:  # e.g. LLMInterface
                 warnings.warn(
                     "The default value of 'return_context' will change from 'False'"
                     " to 'True' in a future version.",
                     DeprecationWarning,
                 )
                 return_context = False
-            else:
-                return_context = True
         try:
             validated_data = RagSearchModel(
                 query_text=query_text,
@@ -164,9 +164,7 @@ def search(
         logger.debug("RAG: retriever_result=%s", prettify(retriever_result))
         logger.debug("RAG: prompt=%s", prompt)
 
-        if isinstance(self.llm, LLMInterfaceV2) or self.llm.__module__.startswith(
-            "langchain"
-        ):
+        if self.is_langchain_compatible():
             messages = legacy_inputs_to_messages(
                 prompt=prompt,
                 message_history=message_history,
@@ -206,9 +204,7 @@ def _build_query(
             summarization_prompt = self._chat_summary_prompt(
                 message_history=message_history
             )
-            if isinstance(self.llm, LLMInterfaceV2) or self.llm.__module__.startswith(
-                "langchain"
-            ):
+            if self.is_langchain_compatible():
                 messages = legacy_inputs_to_messages(
                     summarization_prompt,
                     system_instruction=summary_system_message,
@@ -227,6 +223,12 @@ def _build_query(
             return self.conversation_prompt(summary=summary, current_query=query_text)
         return query_text
 
+    def is_langchain_compatible(self) -> bool:
+        """Checks if the LLM is compatible with LangChain."""
+        return isinstance(self.llm, LLMInterfaceV2) or self.llm.__module__.startswith(
+            "langchain"
+        )
+
     def _chat_summary_prompt(self, message_history: List[LLMMessage]) -> str:
         message_list = [
             f"{message['role']}: {message['content']}" for message in message_history
