Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 53 additions & 56 deletions gpt_researcher/actions/retriever.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,62 +10,59 @@ def get_retriever(retriever: str):
retriever: Retriever class

"""
# NOTE(review): legacy Python 3.10+ ``match`` dispatch (the removed side of
# this diff, replaced by an if/elif chain for Python 3.9 compatibility).
# Each case lazily imports its search class so that only the selected
# retriever's optional dependencies are loaded.
match retriever:
    case "google":
        from gpt_researcher.retrievers import GoogleSearch

        return GoogleSearch
    case "searx":
        from gpt_researcher.retrievers import SearxSearch

        return SearxSearch
    case "searchapi":
        from gpt_researcher.retrievers import SearchApiSearch

        return SearchApiSearch
    case "serpapi":
        from gpt_researcher.retrievers import SerpApiSearch

        return SerpApiSearch
    case "serper":
        from gpt_researcher.retrievers import SerperSearch

        return SerperSearch
    case "duckduckgo":
        from gpt_researcher.retrievers import Duckduckgo

        return Duckduckgo
    case "bing":
        from gpt_researcher.retrievers import BingSearch

        return BingSearch
    case "arxiv":
        from gpt_researcher.retrievers import ArxivSearch

        return ArxivSearch
    case "tavily":
        from gpt_researcher.retrievers import TavilySearch

        return TavilySearch
    case "exa":
        from gpt_researcher.retrievers import ExaSearch

        return ExaSearch
    case "semantic_scholar":
        from gpt_researcher.retrievers import SemanticScholarSearch

        return SemanticScholarSearch
    case "pubmed_central":
        from gpt_researcher.retrievers import PubMedCentralSearch

        return PubMedCentralSearch
    case "custom":
        from gpt_researcher.retrievers import CustomRetriever

        return CustomRetriever

    case _:
        # Unknown retriever name: silently returns None, so every caller
        # must check for it (the replacement code raises ValueError instead).
        return None
if retriever == "google":
from gpt_researcher.retrievers import GoogleSearch
return GoogleSearch
elif retriever == "searx":
from gpt_researcher.retrievers import SearxSearch
return SearxSearch
elif retriever == "searchapi":
from gpt_researcher.retrievers import SearchApiSearch
return SearchApiSearch
elif retriever == "serpapi":
from gpt_researcher.retrievers import SerpApiSearch
return SerpApiSearch
elif retriever == "serper":
from gpt_researcher.retrievers import SerperSearch
return SerperSearch
elif retriever == "duckduckgo":
from gpt_researcher.retrievers import Duckduckgo
return Duckduckgo
elif retriever == "bing":
from gpt_researcher.retrievers import BingSearch
return BingSearch
elif retriever == "arxiv":
from gpt_researcher.retrievers import ArxivSearch
return ArxivSearch
elif retriever == "tavily":
from gpt_researcher.retrievers import TavilySearch
return TavilySearch
elif retriever == "exa":
from gpt_researcher.retrievers import ExaSearch
return ExaSearch
elif retriever == "semantic_scholar":
from gpt_researcher.retrievers import SemanticScholarSearch
return SemanticScholarSearch
elif retriever == "pubmed_central":
from gpt_researcher.retrievers import PubMedCentralSearch
return PubMedCentralSearch
elif retriever == "custom":
from gpt_researcher.retrievers import CustomRetriever
return CustomRetriever
elif retriever == "brave":
from gpt_researcher.retrievers import BraveSearch
return BraveSearch
elif retriever == "you":
from gpt_researcher.retrievers import YouSearch
return YouSearch
elif retriever == "perplexity":
from gpt_researcher.retrievers import PerplexitySearch
return PerplexitySearch
elif retriever == "local_documents":
from gpt_researcher.retrievers import LocalDocumentRetriever
return LocalDocumentRetriever
else:
raise ValueError(f"Retriever {retriever} not found")


def get_retrievers(headers: dict[str, str], cfg: Config):
Expand Down
31 changes: 16 additions & 15 deletions gpt_researcher/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,21 +61,22 @@ def _handle_deprecated_attributes(self) -> None:
os.environ["EMBEDDING_PROVIDER"] or self.embedding_provider
)

# NOTE(review): legacy Python 3.10+ ``match`` version (removed side of this
# diff).  Resolves the default embedding model for the deprecated
# EMBEDDING_PROVIDER environment variable.
match os.environ["EMBEDDING_PROVIDER"]:
    case "ollama":
        # No universal Ollama default — the model must be named explicitly
        # via OLLAMA_EMBEDDING_MODEL (raises KeyError if unset).
        self.embedding_model = os.environ["OLLAMA_EMBEDDING_MODEL"]
    case "custom":
        self.embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "custom")
    case "openai":
        self.embedding_model = "text-embedding-3-large"
    case "azure_openai":
        self.embedding_model = "text-embedding-3-large"
    case "huggingface":
        self.embedding_model = "sentence-transformers/all-MiniLM-L6-v2"
    case "google_genai":
        self.embedding_model = "text-embedding-004"
    case _:
        # Bare Exception with no provider name — callers cannot tell which
        # value was rejected.
        raise Exception("Embedding provider not found.")
# Resolve the default embedding model for the (deprecated) EMBEDDING_PROVIDER
# environment variable.  if/elif + mapping instead of ``match`` for
# Python 3.9 compatibility.
embedding_provider = os.environ["EMBEDDING_PROVIDER"]
# Providers whose default model is a fixed constant.
_default_models = {
    "openai": "text-embedding-3-large",
    "azure_openai": "text-embedding-3-large",
    "huggingface": "sentence-transformers/all-MiniLM-L6-v2",
    "google_genai": "text-embedding-004",
}
if embedding_provider == "ollama":
    # No universal Ollama default — the model must be named explicitly
    # via OLLAMA_EMBEDDING_MODEL (raises KeyError if unset).
    self.embedding_model = os.environ["OLLAMA_EMBEDDING_MODEL"]
elif embedding_provider == "custom":
    self.embedding_model = os.getenv("OPENAI_EMBEDDING_MODEL", "custom")
elif embedding_provider in _default_models:
    self.embedding_model = _default_models[embedding_provider]
else:
    # ValueError naming the offending provider, consistent with the
    # retriever/embedding lookups elsewhere in this changeset (the old
    # code raised a bare ``Exception("Embedding provider not found.")``,
    # which hid which value was rejected).  ValueError is an Exception
    # subclass, so existing ``except Exception`` handlers still match.
    raise ValueError(f"Embedding provider {embedding_provider} not found.")

_deprecation_warning = (
"LLM_PROVIDER, FAST_LLM_MODEL and SMART_LLM_MODEL are deprecated and "
Expand Down
173 changes: 86 additions & 87 deletions gpt_researcher/memory/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,93 +27,92 @@
class Memory:
def __init__(self, embedding_provider: str, model: str, **embdding_kwargs: Any):
_embeddings = None
# NOTE(review): legacy Python 3.10+ ``match`` version (removed side of this
# diff).  Builds the LangChain embeddings client for the given provider;
# each case lazily imports its integration so only the selected provider's
# optional dependency must be installed.  (``embdding_kwargs`` [sic] —
# typo inherited from the enclosing __init__ signature.)
match embedding_provider:
    case "custom":
        from langchain_openai import OpenAIEmbeddings

        # OpenAI-compatible local endpoint: dummy API key and localhost
        # base URL are used when the env vars are absent.
        _embeddings = OpenAIEmbeddings(
            model=model,
            openai_api_key=os.getenv("OPENAI_API_KEY", "custom"),
            openai_api_base=os.getenv(
                "OPENAI_BASE_URL", "http://localhost:1234/v1"
            ),  # default for lmstudio
            check_embedding_ctx_length=False,
            **embdding_kwargs,
        )  # quick fix for lmstudio
    case "openai":
        from langchain_openai import OpenAIEmbeddings

        _embeddings = OpenAIEmbeddings(model=model, **embdding_kwargs)
    case "azure_openai":
        from langchain_openai import AzureOpenAIEmbeddings

        # Requires all three AZURE_OPENAI_* env vars (KeyError if missing).
        _embeddings = AzureOpenAIEmbeddings(
            model=model,
            azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
            openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
            openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
            **embdding_kwargs,
        )
    case "cohere":
        from langchain_cohere import CohereEmbeddings

        _embeddings = CohereEmbeddings(model=model, **embdding_kwargs)
    case "google_vertexai":
        from langchain_google_vertexai import VertexAIEmbeddings

        _embeddings = VertexAIEmbeddings(model=model, **embdding_kwargs)
    case "google_genai":
        from langchain_google_genai import GoogleGenerativeAIEmbeddings

        _embeddings = GoogleGenerativeAIEmbeddings(
            model=model, **embdding_kwargs
        )
    case "fireworks":
        from langchain_fireworks import FireworksEmbeddings

        _embeddings = FireworksEmbeddings(model=model, **embdding_kwargs)
    case "ollama":
        from langchain_ollama import OllamaEmbeddings

        _embeddings = OllamaEmbeddings(
            model=model,
            base_url=os.environ["OLLAMA_BASE_URL"],
            **embdding_kwargs,
        )
    case "together":
        from langchain_together import TogetherEmbeddings

        _embeddings = TogetherEmbeddings(model=model, **embdding_kwargs)
    case "mistralai":
        from langchain_mistralai import MistralAIEmbeddings

        _embeddings = MistralAIEmbeddings(model=model, **embdding_kwargs)
    case "huggingface":
        from langchain_huggingface import HuggingFaceEmbeddings

        # HuggingFace uses ``model_name=`` rather than ``model=``.
        _embeddings = HuggingFaceEmbeddings(model_name=model, **embdding_kwargs)
    case "nomic":
        from langchain_nomic import NomicEmbeddings

        _embeddings = NomicEmbeddings(model=model, **embdding_kwargs)
    case "voyageai":
        from langchain_voyageai import VoyageAIEmbeddings

        _embeddings = VoyageAIEmbeddings(
            voyage_api_key=os.environ["VOYAGE_API_KEY"],
            model=model,
            **embdding_kwargs,
        )
    case "dashscope":
        from langchain_community.embeddings import DashScopeEmbeddings

        _embeddings = DashScopeEmbeddings(model=model, **embdding_kwargs)
    case "bedrock":
        from langchain_aws.embeddings import BedrockEmbeddings

        # Bedrock uses ``model_id=`` rather than ``model=``.
        _embeddings = BedrockEmbeddings(model_id=model, **embdding_kwargs)
    case _:
        raise Exception("Embedding not found.")
# Build the LangChain embeddings client for ``embedding_provider``.
# Each branch imports its integration lazily so that only the selected
# provider's optional dependency needs to be installed.  (if/elif instead
# of ``match`` for Python 3.9 compatibility; ``embdding_kwargs`` [sic] —
# typo inherited from the enclosing __init__ signature.)
if embedding_provider == "custom":
    from langchain_openai import OpenAIEmbeddings

    # OpenAI-compatible local endpoint (e.g. LM Studio): dummy API key
    # and localhost base URL are used when the env vars are absent.
    _embeddings = OpenAIEmbeddings(
        model=model,
        openai_api_key=os.getenv("OPENAI_API_KEY", "custom"),
        openai_api_base=os.getenv(
            "OPENAI_BASE_URL", "http://localhost:1234/v1"
        ),  # default for lmstudio
        check_embedding_ctx_length=False,  # quick fix for lmstudio
        **embdding_kwargs,
    )
elif embedding_provider == "openai":
    from langchain_openai import OpenAIEmbeddings

    _embeddings = OpenAIEmbeddings(model=model, **embdding_kwargs)
elif embedding_provider == "azure_openai":
    from langchain_openai import AzureOpenAIEmbeddings

    # Requires all three AZURE_OPENAI_* env vars (KeyError if missing).
    _embeddings = AzureOpenAIEmbeddings(
        model=model,
        azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
        openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
        openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
        **embdding_kwargs,
    )
elif embedding_provider == "cohere":
    from langchain_cohere import CohereEmbeddings

    _embeddings = CohereEmbeddings(model=model, **embdding_kwargs)
elif embedding_provider == "google_vertexai":
    from langchain_google_vertexai import VertexAIEmbeddings

    _embeddings = VertexAIEmbeddings(model=model, **embdding_kwargs)
elif embedding_provider == "google_genai":
    from langchain_google_genai import GoogleGenerativeAIEmbeddings

    _embeddings = GoogleGenerativeAIEmbeddings(
        model=model, **embdding_kwargs
    )
elif embedding_provider == "fireworks":
    from langchain_fireworks import FireworksEmbeddings

    _embeddings = FireworksEmbeddings(model=model, **embdding_kwargs)
elif embedding_provider == "ollama":
    from langchain_ollama import OllamaEmbeddings

    _embeddings = OllamaEmbeddings(
        model=model,
        base_url=os.environ["OLLAMA_BASE_URL"],
        **embdding_kwargs,
    )
elif embedding_provider == "together":
    from langchain_together import TogetherEmbeddings

    _embeddings = TogetherEmbeddings(model=model, **embdding_kwargs)
elif embedding_provider == "mistralai":
    from langchain_mistralai import MistralAIEmbeddings

    _embeddings = MistralAIEmbeddings(model=model, **embdding_kwargs)
elif embedding_provider == "huggingface":
    from langchain_huggingface import HuggingFaceEmbeddings

    # HuggingFace uses ``model_name=`` rather than ``model=``.
    _embeddings = HuggingFaceEmbeddings(model_name=model, **embdding_kwargs)
elif embedding_provider == "nomic":
    from langchain_nomic import NomicEmbeddings

    _embeddings = NomicEmbeddings(model=model, **embdding_kwargs)
elif embedding_provider == "voyageai":
    from langchain_voyageai import VoyageAIEmbeddings

    _embeddings = VoyageAIEmbeddings(
        voyage_api_key=os.environ["VOYAGE_API_KEY"],
        model=model,
        **embdding_kwargs,
    )
elif embedding_provider == "dashscope":
    from langchain_community.embeddings import DashScopeEmbeddings

    _embeddings = DashScopeEmbeddings(model=model, **embdding_kwargs)
elif embedding_provider == "bedrock":
    from langchain_aws.embeddings import BedrockEmbeddings

    # Bedrock uses ``model_id=`` rather than ``model=``.
    _embeddings = BedrockEmbeddings(model_id=model, **embdding_kwargs)
else:
    # Unknown provider *name*.  A missing library would surface as an
    # ImportError inside the matching branch and never reach this else,
    # so the previous "pip install [pip-package-name]" hint here was
    # misleading; report the unrecognized name instead.
    raise ValueError(f"Embedding provider {embedding_provider} not found.")

self._embeddings = _embeddings

Expand Down
Loading