Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions dags/model_training_dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from airflow import DAG
from airflow.operators.python import PythonOperator
from datetime import datetime, timedelta
import subprocess
import sys
import os

# Path to your Lore repo
LORE_PATH = r"C:\MY WORK\open-source\lore"

# Default args for DAG
default_args = {
'owner': 'mohsinkhan85090',
'depends_on_past': False,
'email_on_failure': False,
'email_on_retry': False,
'retries': 1,
'retry_delay': timedelta(minutes=5),
}

# Task 1: Train model
def train_model():
# Run the main script (adjust if your training function is elsewhere)
subprocess.run([sys.executable, os.path.join(LORE_PATH, 'lore', '__main__.py'), 'train'], check=True)

# Task 2: Evaluate model
def evaluate_model():
subprocess.run([sys.executable, os.path.join(LORE_PATH, 'lore', '__main__.py'), 'evaluate'], check=True)

# Define DAG
with DAG(
'lore_model_training',
default_args=default_args,
description='Automate Lore model training using Airflow',
schedule_interval='@daily',
start_date=datetime(2025, 9, 19),
catchup=False
) as dag:

train_task = PythonOperator(
task_id='train_model',
python_callable=train_model
)

evaluate_task = PythonOperator(
task_id='evaluate_model',
python_callable=evaluate_model
)

# Task dependencies
train_task >> evaluate_task
7 changes: 7 additions & 0 deletions lore/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,14 @@

import lore
from lore import ansi, env, util


from lore.util import timer, which
if platform.system() == "Windows":
subprocess.run([env.BIN_FLASK] + sys.argv[1:], shell=True)
else:
os.execv(env.BIN_FLASK, args)



logger = logging.getLogger(__name__)
Expand Down
23 changes: 15 additions & 8 deletions lore/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,24 @@

import glob
import locale
import os
import re
import socket
import subprocess
import sys
import platform
from io import open

import pkg_resources

from lore import ansi

import os
import sys
import subprocess
# Determine the path to flask.exe depending on OS
if sys.platform == "win32":
BIN_FLASK = os.path.normpath(os.path.join(os.getcwd(), ".venv", "Scripts", "flask.exe"))
else:
BIN_FLASK = os.path.join(os.getcwd(), ".venv", "bin", "flask")

# -- Python 2/3 Compatability ------------------------------------------------

Expand Down Expand Up @@ -170,10 +176,8 @@ def validate():
)
)


def launch():
"""Ensure that python is running from the Lore virtualenv past this point.
"""
"""Ensure that python is running from the Lore virtualenv past this point."""
if launched():
check_version()
os.chdir(ROOT)
Expand All @@ -188,8 +192,11 @@ def launch():
import lore.__main__
lore.__main__.install(None, None)

reboot('--env-launched')

# Windows-safe call
if platform.system() == "Windows":
subprocess.run([BIN_FLASK] + sys.argv[1:], shell=True)
else:
os.execv(BIN_FLASK, sys.argv)

def reboot(*args):
"""Reboot python in the Lore virtualenv
Expand Down Expand Up @@ -285,7 +292,7 @@ def get_config(path):
conf = open(path, 'rt').read()
conf = os.path.expandvars(conf)

config = configparser.SafeConfigParser()
config = configparser.ConfigParser()
if sys.version_info[0] == 2:
from io import StringIO
config.readfp(StringIO(unicode(conf)))
Expand Down
44 changes: 21 additions & 23 deletions lore/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,34 +338,32 @@ def report_exception(exc_type=None, value=None, tb=None):
sys.excepthook = report_exception

# Librato
import os
import threading

_librato = None
_librato_aggregator = None
_librato_timer = None
_librato_start = None
_librato_lock = threading.RLock()

# Only connect to Librato if both env variables are present
LIBRATO_USER = os.getenv('LIBRATO_USER')
LIBRATO_TOKEN = os.getenv('LIBRATO_TOKEN')

if LIBRATO_USER and LIBRATO_TOKEN:
try:
import librato
from librato.aggregator import Aggregator

# client side aggregation
LIBRATO_MIN_AGGREGATION_PERIOD = 5
LIBRATO_MAX_AGGREGATION_PERIOD = 60

_librato = librato.connect(LIBRATO_USER, LIBRATO_TOKEN)
logger.info('Connected to Librato with user: %s' % LIBRATO_USER)
except Exception as e:
_librato = None
if os.getenv('LIBRATO_USER'):
try:
_librato = librato.connect(os.getenv('LIBRATO_USER'), os.getenv('LIBRATO_TOKEN'))
_librato_aggregator = None
_librato_timer = None
_librato_start = None
_librato_lock = threading.RLock()
logger.info('connected to librato with user: %s' % os.getenv('LIBRATO_USER'))
except:
logger.exception('unable to start librato')
report_exception()
_librato = None
else:
logger.warning('librato variables not found')

except env.ModuleNotFoundError:
pass

logger.warning('Could not connect to Librato: %s' % str(e))
else:
logger.warning('Librato environment variables not found')


def librato_record(name, value):
global _librato, _librato_lock, _librato_aggregator, _librato_timer, _librato_start

Expand Down
10 changes: 9 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1,9 @@
-e .
click==8.2.1
colorama==0.4.6
Flask==0.12.5
itsdangerous==2.2.0
Jinja2==3.1.6
MarkupSafe==3.0.2
Werkzeug==0.16.1
-e .

Expand Down