Skip to content

Commit a316a11

Browse files
Merge pull request #21 from TheExplainthis/develop
Add MongoDB
2 parents f3f7467 + 4300f76 commit a316a11

File tree

3 files changed

+76
-7
lines changed

3 files changed

+76
-7
lines changed

main.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,18 @@
1515
from src.models import OpenAIModel
1616
from src.memory import Memory
1717
from src.logger import logger
18-
from src.storage import Storage
18+
from src.storage import Storage, FileStorage, MongoStorage
1919
from src.utils import get_role_and_content
2020
from src.service.youtube import Youtube, YoutubeTranscriptReader
2121
from src.service.website import Website, WebsiteReader
22+
from src.mongodb import mongodb
2223

2324
load_dotenv('.env')
2425

2526
app = Flask(__name__)
2627
line_bot_api = LineBotApi(os.getenv('LINE_CHANNEL_ACCESS_TOKEN'))
2728
handler = WebhookHandler(os.getenv('LINE_CHANNEL_SECRET'))
28-
storage = Storage('db.json')
29+
storage = None
2930
youtube = Youtube(step=4)
3031
website = Website()
3132

@@ -62,8 +63,9 @@ def handle_text_message(event):
6263
if not is_successful:
6364
raise ValueError('Invalid API token')
6465
model_management[user_id] = model
65-
api_keys[user_id] = api_key
66-
storage.save(api_keys)
66+
storage.save({
67+
user_id: api_key
68+
})
6769
msg = TextSendMessage(text='Token 有效,註冊成功')
6870

6971
elif text.startswith('/指令說明'):
@@ -180,6 +182,11 @@ def home():
180182

181183

182184
if __name__ == "__main__":
185+
if os.getenv('USE_MONGO'):
186+
mongodb.connect_to_database()
187+
storage = Storage(MongoStorage(mongodb.db))
188+
else:
189+
storage = Storage(FileStorage('db.json'))
183190
try:
184191
data = storage.load()
185192
for user_id in data.keys():

src/mongodb.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import os
2+
3+
from pymongo import MongoClient
4+
5+
6+
class MongoDB():
7+
"""
8+
Environment Variables:
9+
MONGODB__PATH
10+
MONGODB__DBNAME
11+
"""
12+
client: None
13+
db: None
14+
15+
def connect_to_database(self, mongo_path=None, db_name=None):
16+
mongo_path = mongo_path or os.getenv('MONGODB__PATH')
17+
db_name = db_name or os.getenv('MONGODB__DBNAME')
18+
self.client = MongoClient(mongo_path)
19+
assert self.client.config.command('ping')['ok'] == 1.0
20+
self.db = self.client[db_name]
21+
22+
23+
mongodb = MongoDB()

src/storage.py

Lines changed: 42 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,54 @@
11
import json
2+
import datetime
23

34

4-
class Storage():
5+
class FileStorage:
56
def __init__(self, file_name):
67
self.fine_name = file_name
8+
self.history = {}
79

810
def save(self, data):
11+
self.history.update(data)
912
with open(self.fine_name, 'w', newline='') as f:
10-
json.dump(data, f)
13+
json.dump(self.history, f)
1114

1215
def load(self):
1316
with open(self.fine_name, newline='') as jsonfile:
1417
data = json.load(jsonfile)
15-
return data
18+
self.history = data
19+
return self.history
20+
21+
22+
class MongoStorage:
23+
def __init__(self, db):
24+
self.db = db
25+
26+
def save(self, data):
27+
user_id, api_key = list(data.items())[0]
28+
self.db['api_key'].update_one({
29+
'user_id': user_id
30+
}, {
31+
'$set': {
32+
'user_id': user_id,
33+
'api_key': api_key,
34+
'created_at': datetime.datetime.utcnow()
35+
}
36+
}, upsert=True)
37+
38+
def load(self):
39+
data = list(self.db['api_key'].find())
40+
res = {}
41+
for i in range(len(data)):
42+
res[data[i]['user_id']] = data[i]['api_key']
43+
return res
44+
45+
46+
class Storage:
47+
def __init__(self, storage):
48+
self.storage = storage
49+
50+
def save(self, data):
51+
self.storage.save(data)
52+
53+
def load(self):
54+
return self.storage.load()

0 commit comments

Comments
 (0)