4 changes: 2 additions & 2 deletions gpt_researcher/actions/agent_creator.py
@@ -32,13 +32,13 @@ async def choose_agent(

     try:
         response = await create_chat_completion(
-            model=cfg.smart_llm_model,
+            model=cfg.fast_llm_model,
             messages=[
                 {"role": "system", "content": f"{prompt_family.auto_agent_instructions()}"},
                 {"role": "user", "content": f"task: {query}"},
             ],
             temperature=0.15,
-            llm_provider=cfg.smart_llm_provider,
+            llm_provider=cfg.fast_llm_provider,
             llm_kwargs=cfg.llm_kwargs,
             cost_callback=cost_callback,
             **kwargs
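Agent selection is a short classification call, so it now runs on the fast LLM tier instead of the smart one. As a rough sketch of how the two tiers are typically configured (the FAST_LLM/SMART_LLM names and the provider:model format follow gpt-researcher's existing config convention; the model choices here are placeholders):

    import os

    # Hypothetical tier setup: the fast model handles routing/classification
    # such as choose_agent; the smart model is reserved for report writing.
    os.environ["FAST_LLM"] = "openai:gpt-4o-mini"
    os.environ["SMART_LLM"] = "openai:gpt-4o"

    from gpt_researcher.config.config import Config

    cfg = Config()
    print(cfg.fast_llm_provider, cfg.fast_llm_model)  # openai gpt-4o-mini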
21 changes: 21 additions & 0 deletions gpt_researcher/config/config.py
@@ -62,8 +62,29 @@ def _set_llm_attributes(self) -> None:
         self.smart_llm_provider, self.smart_llm_model = self.parse_llm(self.smart_llm)
         self.strategic_llm_provider, self.strategic_llm_model = self.parse_llm(self.strategic_llm)
         self.reasoning_effort = self.parse_reasoning_effort(os.getenv("REASONING_EFFORT"))
+
+        # Set base URLs for LLM and embedding endpoints
+        self.llm_base_url = getattr(self, 'llm_endpoint', 'http://localhost:8080/v1')
+        self.embedding_base_url = getattr(self, 'embedding_endpoint', 'http://localhost:8081/v1')
+
+        # Update LLM kwargs with the appropriate base URL
+        if not self.llm_kwargs.get('base_url'):
+            self.llm_kwargs['base_url'] = self.llm_base_url
+
+        # Update embedding kwargs with the appropriate base URL
+        if not self.embedding_kwargs.get('base_url'):
+            self.embedding_kwargs['base_url'] = self.embedding_base_url
 
     def _handle_deprecated_attributes(self) -> None:
+        # Handle environment variables for endpoints
+        if os.getenv("LLM_ENDPOINT"):
+            self.llm_base_url = os.environ["LLM_ENDPOINT"]
+            self.llm_kwargs['base_url'] = self.llm_base_url
+
+        if os.getenv("EMBEDDING_ENDPOINT"):
+            self.embedding_base_url = os.environ["EMBEDDING_ENDPOINT"]
+            self.embedding_kwargs['base_url'] = self.embedding_base_url
+
         if os.getenv("EMBEDDING_PROVIDER") is not None:
             warnings.warn(
                 "EMBEDDING_PROVIDER is deprecated and will be removed soon. Use EMBEDDING instead.",
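The net effect is a simple precedence: anything already present in llm_kwargs/embedding_kwargs wins, then the LLM_ENDPOINT/EMBEDDING_ENDPOINT environment variables, then the localhost defaults. A minimal sanity check, assuming Config() runs both of these methods during initialization (the ports are just this PR's defaults):

    import os

    # Point both tiers at locally hosted OpenAI-compatible servers.
    os.environ["LLM_ENDPOINT"] = "http://localhost:8080/v1"
    os.environ["EMBEDDING_ENDPOINT"] = "http://localhost:8081/v1"

    from gpt_researcher.config.config import Config

    cfg = Config()
    assert cfg.llm_kwargs["base_url"] == "http://localhost:8080/v1"
    assert cfg.embedding_kwargs["base_url"] == "http://localhost:8081/v1"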
2 changes: 2 additions & 0 deletions gpt_researcher/config/variables/default.py
@@ -42,4 +42,6 @@
"MCP_ALLOWED_ROOT_PATHS": [], # List of allowed root paths for local file access
"MCP_STRATEGY": "fast", # MCP execution strategy: "fast", "deep", "disabled"
"REASONING_EFFORT": "medium",
"LLM_ENDPOINT": "http://localhost:8080/v1", # LLM endpoint
"EMBEDDING_ENDPOINT": "http://localhost:8081/v1", # Embedding endpoint
}
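These defaults only kick in when neither explicit kwargs nor the environment supply a URL. To inspect them (assuming the dict in this module keeps its existing DEFAULT_CONFIG name):

    from gpt_researcher.config.variables.default import DEFAULT_CONFIG

    # Fallback endpoints used when LLM_ENDPOINT / EMBEDDING_ENDPOINT are unset.
    print(DEFAULT_CONFIG["LLM_ENDPOINT"])        # http://localhost:8080/v1
    print(DEFAULT_CONFIG["EMBEDDING_ENDPOINT"])  # http://localhost:8081/v1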
23 changes: 20 additions & 3 deletions gpt_researcher/llm_provider/generic/base.py
@@ -90,10 +90,22 @@ def __init__(self, llm, chat_log: str | None = None, verbose: bool = True):
         self.verbose = verbose
 
     @classmethod
     def from_provider(cls, provider: str, chat_log: str | None = None, verbose: bool=True, **kwargs: Any):
+        # Get the appropriate base URL from kwargs or environment
+        base_url = kwargs.pop('base_url', None) or os.environ.get(f"{provider.upper()}_BASE_URL")
+
+        if base_url:
+            kwargs['base_url'] = base_url
+            if verbose:
+                print(f"Using {provider} endpoint: {base_url}")
+
         if provider == "openai":
             _check_pkg("langchain_openai")
             from langchain_openai import ChatOpenAI
 
+            # Handle custom endpoints for OpenAI-compatible APIs
+            if 'base_url' in kwargs and 'api_key' not in kwargs:
+                kwargs['api_key'] = 'dummy'  # Some APIs don't require a key
+
             llm = ChatOpenAI(**kwargs)
         elif provider == "anthropic":
             _check_pkg("langchain_anthropic")
@@ -133,8 +145,13 @@ def from_provider(cls, provider: str, chat_log: str | None = None, verbose: bool
_check_pkg("langchain_community")
_check_pkg("langchain_ollama")
from langchain_ollama import ChatOllama

llm = ChatOllama(base_url=os.environ["OLLAMA_BASE_URL"], **kwargs)

# Use provided base_url or fall back to environment variable
base_url = kwargs.pop('base_url', os.environ.get("OLLAMA_BASE_URL"))
if base_url:
kwargs['base_url'] = base_url

llm = ChatOllama(**kwargs)
elif provider == "together":
_check_pkg("langchain_together")
from langchain_together import ChatTogether
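With this hook, any provider can be redirected either by passing base_url directly or by setting a <PROVIDER>_BASE_URL environment variable. A minimal sketch against the factory in this module (the URL and model name are placeholders):

    from gpt_researcher.llm_provider.generic.base import GenericLLMProvider

    # base_url is set and no api_key is passed, so the factory falls back
    # to the 'dummy' key for OpenAI-compatible servers that don't need one.
    provider = GenericLLMProvider.from_provider(
        "openai",
        model="llama-3.1-8b-instruct",  # placeholder local model
        base_url="http://localhost:8080/v1",
        temperature=0.2,
    )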
104 changes: 55 additions & 49 deletions gpt_researcher/memory/embeddings.py
@@ -27,105 +27,111 @@


 class Memory:
-    def __init__(self, embedding_provider: str, model: str, **embdding_kwargs: Any):
+    def __init__(self, embedding_provider: str, model: str, **embedding_kwargs: Any):
         _embeddings = None
+
+        # Get base URL from kwargs or environment
+        base_url = embedding_kwargs.pop('base_url', None) or os.environ.get('EMBEDDING_ENDPOINT')
+
         match embedding_provider:
-            case "custom":
+            case "custom" | "openai":
                 from langchain_openai import OpenAIEmbeddings
 
+                # For custom endpoints, use a dummy key if none provided
+                api_key = os.getenv("OPENAI_API_KEY", "dummy")
+                if embedding_provider == "custom" and not base_url:
+                    base_url = os.getenv("OPENAI_BASE_URL", "http://localhost:1234/v1")
+
                 _embeddings = OpenAIEmbeddings(
                     model=model,
-                    openai_api_key=os.getenv("OPENAI_API_KEY", "custom"),
-                    openai_api_base=os.getenv(
-                        "OPENAI_BASE_URL", "http://localhost:1234/v1"
-                    ),  # default for lmstudio
+                    openai_api_key=api_key,
+                    openai_api_base=base_url,
                     check_embedding_ctx_length=False,
-                    **embdding_kwargs,
-                )  # quick fix for lmstudio
-            case "openai":
-                from langchain_openai import OpenAIEmbeddings
-
-                _embeddings = OpenAIEmbeddings(model=model, **embdding_kwargs)
+                    **embedding_kwargs,
+                )
             case "azure_openai":
                 from langchain_openai import AzureOpenAIEmbeddings
 
                 _embeddings = AzureOpenAIEmbeddings(
                     model=model,
                     azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
                     openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
-                    openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
-                    **embdding_kwargs,
+                    openai_api_version=os.getenv("AZURE_OPENAI_API_VERSION", "2023-05-15"),
+                    azure_deployment=model,
+                    **embedding_kwargs,
                 )
             case "cohere":
                 from langchain_cohere import CohereEmbeddings
 
-                _embeddings = CohereEmbeddings(model=model, **embdding_kwargs)
+                _embeddings = CohereEmbeddings(model=model, **embedding_kwargs)
 
             case "google_vertexai":
                 from langchain_google_vertexai import VertexAIEmbeddings
 
-                _embeddings = VertexAIEmbeddings(model=model, **embdding_kwargs)
+                _embeddings = VertexAIEmbeddings(model=model, **embedding_kwargs)
 
             case "google_genai":
                 from langchain_google_genai import GoogleGenerativeAIEmbeddings
 
                 _embeddings = GoogleGenerativeAIEmbeddings(
-                    model=model, **embdding_kwargs
+                    model=model,
+                    **embedding_kwargs,
                 )
 
             case "fireworks":
                 from langchain_fireworks import FireworksEmbeddings
 
-                _embeddings = FireworksEmbeddings(model=model, **embdding_kwargs)
+                _embeddings = FireworksEmbeddings(model=model, **embedding_kwargs)
 
             case "gigachat":
                 from langchain_gigachat import GigaChatEmbeddings
 
-                _embeddings = GigaChatEmbeddings(model=model, **embdding_kwargs)
+                _embeddings = GigaChatEmbeddings(model=model, **embedding_kwargs)
 
             case "ollama":
                 from langchain_ollama import OllamaEmbeddings
 
+                # Use provided base_url or fall back to environment variable
+                ollama_base = base_url or os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
                 _embeddings = OllamaEmbeddings(
                     model=model,
-                    base_url=os.environ["OLLAMA_BASE_URL"],
-                    **embdding_kwargs,
+                    base_url=ollama_base,
+                    **embedding_kwargs,
                 )
 
             case "together":
                 from langchain_together import TogetherEmbeddings
 
-                _embeddings = TogetherEmbeddings(model=model, **embdding_kwargs)
+                _embeddings = TogetherEmbeddings(model=model, **embedding_kwargs)
 
             case "mistralai":
                 from langchain_mistralai import MistralAIEmbeddings
 
-                _embeddings = MistralAIEmbeddings(model=model, **embdding_kwargs)
+                _embeddings = MistralAIEmbeddings(model=model, **embedding_kwargs)
 
             case "huggingface":
                 from langchain_huggingface import HuggingFaceEmbeddings
 
-                _embeddings = HuggingFaceEmbeddings(model_name=model, **embdding_kwargs)
+                _embeddings = HuggingFaceEmbeddings(
+                    model_name=model, **embedding_kwargs
+                )
 
             case "nomic":
                 from langchain_nomic import NomicEmbeddings
 
-                _embeddings = NomicEmbeddings(model=model, **embdding_kwargs)
+                _embeddings = NomicEmbeddings(model=model, **embedding_kwargs)
 
             case "voyageai":
                 from langchain_voyageai import VoyageAIEmbeddings
 
                 _embeddings = VoyageAIEmbeddings(
-                    voyage_api_key=os.environ["VOYAGE_API_KEY"],
                     model=model,
-                    **embdding_kwargs,
+                    voyage_api_key=os.getenv("VOYAGE_API_KEY"),
+                    **embedding_kwargs,
                 )
 
             case "dashscope":
                 from langchain_community.embeddings import DashScopeEmbeddings
 
-                _embeddings = DashScopeEmbeddings(model=model, **embdding_kwargs)
+                _embeddings = DashScopeEmbeddings(model=model, **embedding_kwargs)
 
             case "bedrock":
                 from langchain_aws.embeddings import BedrockEmbeddings
 
-                _embeddings = BedrockEmbeddings(model_id=model, **embdding_kwargs)
+                _embeddings = BedrockEmbeddings(model_id=model, **embedding_kwargs)
 
             case "aimlapi":
                 from langchain_openai import OpenAIEmbeddings
 
                 _embeddings = OpenAIEmbeddings(
                     model=model,
-                    openai_api_key=os.getenv("AIMLAPI_API_KEY"),
-                    openai_api_base=os.getenv("AIMLAPI_BASE_URL", "https://api.aimlapi.com/v1"),
-                    **embdding_kwargs,
+                    openai_api_key=os.getenv("OPENAI_API_KEY", "custom"),
+                    openai_api_base=base_url or os.getenv("AIMLAPI_BASE_URL"),
+                    **embedding_kwargs,
                 )
             case _:
                 raise Exception("Embedding not found.")
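Taken together, the embedding path can now target a local server with no code changes. A rough usage sketch, assuming an OpenAI-compatible embedding server on this PR's default port and the constructor signature shown above (the model name is a placeholder):

    import os

    os.environ["EMBEDDING_ENDPOINT"] = "http://localhost:8081/v1"

    from gpt_researcher.memory.embeddings import Memory

    # "custom" now shares the OpenAI code path; the popped base_url kwarg
    # (or EMBEDDING_ENDPOINT) decides where requests are sent.
    memory = Memory(
        embedding_provider="custom",
        model="nomic-embed-text-v1.5",  # placeholder embedding model
    )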