Skip to content

Commit 87668bd

Browse files
brucelbhBruce_Liu
andauthored
Add llm parameter to qa_endpoint (#10)
Co-authored-by: Bruce_Liu <[email protected]>
1 parent da24611 commit 87668bd

File tree

2 files changed

+5
-4
lines changed

2 files changed

+5
-4
lines changed

be/api.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from db.db_service import DBService
1010
from workflow import Workflow
1111
from nlu.llm_azure import model
12+
from nlu.llm_client import llm
1213
from sqlcoder.llm_sqlcoder import SqlCoderLLM
1314
from service import DataLinguaService
1415

@@ -35,15 +36,15 @@ class QARequest(BaseModel):
3536

3637
from fastapi import Cookie
3738

38-
@app.post("/qa")
39-
def qa_endpoint(req: QARequest, sid: Optional[str] = Cookie(None)):
39+
@app.post("/qa/{model_name}")
40+
def qa_endpoint(req: QARequest, sid: Optional[str] = Cookie(None), model_name: str = "client"):
4041
# sid从cookie读取,conversation_id由前端传入
4142
conversation_id = req.conversation_id
4243
# 记录sid->conversation_id映射
4344
if sid:
4445
sid_conversation_map.setdefault(sid, []).append(conversation_id)
4546
# 每次新建 workflow
46-
nlu_agent = NLUAgent(model)
47+
nlu_agent = NLUAgent(model if model_name == 'remote' else llm)
4748
sqlcoder_llm = SqlCoderLLM()
4849
sqlcoder_agent = SQLCoderAgent(sqlcoder_llm)
4950
db_service = DBService(db_path="Chinook.db")

be/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
def run_api():
44
import uvicorn
5-
uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)
5+
uvicorn.run("api:app", host="127.0.0.1", port=8000, reload=True)
66

77
if __name__ == "__main__":
88
run_api()

0 commit comments

Comments
 (0)