1+ import sqlalchemy
2+ from sqlalchemy .sql .expression import Select
3+ from sqlalchemy .orm import Query
4+ from sqlalchemy .ext .automap import automap_base
5+ from sqlalchemy .ext .declarative import declared_attr , declarative_base
6+ import geoalchemy2
7+
8+ import pickle
9+ from enum import Enum
10+ from typing import Dict , Union , Any , Type
11+ import warnings
12+
13+ import pandas as pd
14+
15+ from .exceptions import ObjectNotFound
16+
17+ from .models import annotations
18+ from .models import auth
19+ from .models import core
20+ from .models import cv
21+ from .models import dataquality
22+ from .models import equipment
23+ from .models import extensionproperties
24+ from .models import externalidentifiers
25+ from .models import labanalyses
26+ from .models import provenance
27+ from .models import results
28+ from .models import samplingfeatures
29+ from .models import simulation
30+
31+
class OutputFormats(Enum):
    """Result formats that callers can request from the read helpers in this module."""

    # Every member's value is simply its own name as a string.
    JSON = "JSON"
    DATAFRAME = "DATAFRAME"
    DICT = "DICT"
36+
class Base():
    """Mixin with table naming and dictionary conversion helpers for model classes.

    This class is handed as ``cls`` to ``declarative_base``/``automap_base``
    elsewhere in this module, so its members become available on the models
    built from the resulting base.
    """

    @declared_attr
    def __tablename__(cls) -> str:
        """Return the table name, which is the lower-cased model class name."""
        return str(cls.__name__).lower()

    @classmethod
    def from_dict(cls, attributes_dict: Dict[str, Any]) -> object:
        """Alternative constructor that uses a dictionary to populate attributes.

        Args:
            attributes_dict: Mapping of attribute name to value. Keys that are
                not attributes of the new instance are ignored and empty
                strings are stored as ``None``.

        Returns:
            A new instance of ``cls`` created without constructor arguments.
        """
        instance = cls()
        # Same filtering and '' -> None rules as an in-place update.
        instance.update_from_dict(attributes_dict)
        return instance

    def to_dict(self) -> Dict[str, Any]:
        """Return the instance's column values keyed by the table's column keys."""
        return {column: getattr(self, column) for column in self.__table__.columns.keys()}

    def update_from_dict(self, attributes_dict: Dict[str, Any]) -> None:
        """Update instance attributes based on the provided dictionary.

        Args:
            attributes_dict: Mapping of attribute name to new value. Keys that
                are not existing attributes are skipped; empty strings are
                stored as ``None``.
        """
        for key, value in attributes_dict.items():
            if not hasattr(self, key):
                continue
            # Only real strings are compared with '': objects whose ``==`` does
            # not return a plain bool (e.g. array-like values) would otherwise
            # raise when the comparison result is truth-tested.
            if isinstance(value, str) and value == '':
                value = None
            setattr(self, key, value)

    @classmethod
    def get_pkey_name(cls) -> Union[str, None]:
        """Return the name of the first primary key column, or ``None`` if there is none."""
        for column in cls.__table__.columns:
            if column.primary_key:
                return column.name
        return None
77+
class ODM2Engine:
    """CRUD helper that runs each operation in its own short-lived session.

    Every public method opens a session from ``session_maker`` and leaves the
    ``with`` block (closing the session) before returning.
    """

    def __init__(self, session_maker: sqlalchemy.orm.sessionmaker) -> None:
        """Store the session factory used by all operations."""
        self.session_maker = session_maker

    @staticmethod
    def _get_or_raise(session, model: Type[Base], pkey: Union[int, str]) -> Base:
        """Fetch the ``model`` row identified by ``pkey`` or raise ``ObjectNotFound``."""
        obj = session.get(model, pkey)
        if obj is None:
            pkey_name = model.get_pkey_name()
            raise ObjectNotFound(f"No '{model.__name__}' object found with {pkey_name} = {pkey}")
        return obj

    def read_query(self,
            query: Union[Query, Select],
            output_format: OutputFormats = OutputFormats.JSON,
            orient: str = 'records') -> Union[str, pd.DataFrame, Dict[Any, Any]]:
        """Execute a query and return its rows in the requested format.

        Args:
            query: A ``Select`` statement, or an ORM ``Query`` whose
                ``.statement`` is executed.
            output_format: Desired result format.
            orient: Passed to ``DataFrame.to_json`` for JSON output only; the
                DICT format uses ``DataFrame.to_dict()`` with its defaults.

        Returns:
            A JSON string, a DataFrame or a dictionary depending on
            ``output_format``.

        Raises:
            TypeError: If ``output_format`` is not a supported format. The
                query has already been executed at that point.
        """
        statement = query if isinstance(query, Select) else query.statement
        with self.session_maker() as session:
            df = pd.read_sql(statement, session.bind)

        if output_format == OutputFormats.JSON:
            return df.to_json(orient=orient)
        if output_format == OutputFormats.DATAFRAME:
            return df
        if output_format == OutputFormats.DICT:
            return df.to_dict()
        raise TypeError("Unknown output format")

    def insert_query(self) -> None:
        """Placeholder for bulk insert"""
        #accept dataframe & model
        #use pandas to_sql method to perform insert
        #if except return false or maybe raise error
        #else return true
        raise NotImplementedError

    def create_object(self, obj: Base) -> Union[int, str]:
        """Insert ``obj`` as a new row and return its primary key value.

        Any primary key value already set on ``obj`` is cleared before the
        insert.
        """
        pkey_name = obj.get_pkey_name()
        setattr(obj, pkey_name, None)

        with self.session_maker() as session:
            session.add(obj)
            session.commit()
            # Read the key while the session is still open so the committed
            # instance can be refreshed if its attributes were expired.
            return getattr(obj, pkey_name)

    def read_object(self, model: Type[Base], pkey: Union[int, str],
            output_format: OutputFormats = OutputFormats.DICT,
            orient: str = 'records') -> Union[Dict[str, Any], pd.DataFrame, str]:
        """Load a single object by primary key and return it in the requested format.

        Args:
            model: Model class to look up.
            pkey: Primary key value of the wanted row.
            output_format: DICT returns ``obj.to_dict()``; DATAFRAME and JSON
                return a one-row frame or its JSON serialisation.
            orient: Passed to ``DataFrame.to_json`` for JSON output.

        Raises:
            ObjectNotFound: If no row with that primary key exists.
            TypeError: If ``output_format`` is not a supported format.
        """
        with self.session_maker() as session:
            obj = self._get_or_raise(session, model, pkey)
            session.commit()
            # Column values are copied out before the session closes.
            obj_dict = obj.to_dict()

        if output_format == OutputFormats.DICT:
            return obj_dict

        # Always wrap every value in a one-element list so pandas builds exactly
        # one row, even when a column value is itself a list.
        single_row = {key: [value] for key, value in obj_dict.items()}
        obj_df = pd.DataFrame.from_dict(single_row)
        if output_format == OutputFormats.DATAFRAME:
            return obj_df
        if output_format == OutputFormats.JSON:
            return obj_df.to_json(orient=orient)
        raise TypeError("Unknown output format")

    def update_object(self, model: Type[Base], pkey: Union[int, str], data: Dict[str, Any]) -> None:
        """Apply ``data`` to the object identified by ``pkey`` and commit.

        Args:
            model: Model class to look up.
            pkey: Primary key value of the row to update.
            data: New attribute values. A non-dict argument is converted by
                calling its ``.dict()`` method. An entry for the primary key
                column is ignored.

        Raises:
            ObjectNotFound: If no row with that primary key exists.
        """
        if not isinstance(data, dict):
            data = data.dict()
        pkey_name = model.get_pkey_name()
        # Build a filtered copy instead of popping from the caller's dictionary.
        changes = {key: value for key, value in data.items() if key != pkey_name}
        with self.session_maker() as session:
            obj = self._get_or_raise(session, model, pkey)
            obj.update_from_dict(changes)
            session.commit()

    def delete_object(self, model: Type[Base], pkey: Union[int, str]) -> None:
        """Delete the object identified by ``pkey`` and commit.

        Raises:
            ObjectNotFound: If no row with that primary key exists.
        """
        with self.session_maker() as session:
            obj = self._get_or_raise(session, model, pkey)
            session.delete(obj)
            session.commit()
168+
class Models:
    """Namespace holding one model class per class found in the schema modules.

    Each class from the imported ``models`` modules is rebuilt as a subclass of
    ``base_model`` and stored as an attribute under its original name.
    """

    def __init__(self, base_model) -> None:
        """Build extended models from every schema module imported by this file.

        Args:
            base_model: Class used as the single base of every generated model.
        """
        self._base_model = base_model
        schema_modules = (
            annotations,
            auth,
            core,
            cv,
            dataquality,
            equipment,
            extensionproperties,
            externalidentifiers,
            labanalyses,
            provenance,
            results,
            samplingfeatures,
            simulation,
        )
        for schema in schema_modules:
            self._process_schema(schema)

    def _process_schema(self, schema: Any) -> None:
        """Create and attach an extended model for each class in ``schema``.

        Args:
            schema: Module-like object whose non-dunder class attributes are
                the model templates.
        """
        bases = (self._base_model,)
        for class_name in dir(schema):
            if class_name.startswith('__'):
                continue
            model = getattr(schema, class_name)
            # Skip constants, functions and other non-class attributes: they
            # are not model templates and may not even have a ``__dict__``.
            if not isinstance(model, type):
                continue
            model_attribs = self._trim_dunders(dict(model.__dict__))
            setattr(self, class_name, type(class_name, bases, model_attribs))

    def _trim_dunders(self, dictionary: Dict[str, Any]) -> Dict[str, Any]:
        """Return a copy of ``dictionary`` without keys starting with ``__``."""
        return {k: v for k, v in dictionary.items() if not k.startswith('__')}
198+
class ODM2DataModels():
    """Entry point that builds the model classes and a CRUD engine for a database.

    If ``cache_path`` names an existing pickle file, the table metadata is
    loaded from it. Otherwise an automap base is prepared against the engine
    and, when a cache path was given, its metadata is pickled to that path.
    """

    def __init__(self, engine: sqlalchemy.engine.Engine, schema: str = 'odm2',
            cache_path: Union[str, None] = None) -> None:
        """Create the session factory, the CRUD engine and the model classes.

        Args:
            engine: SQLAlchemy engine used for sessions and model preparation.
            schema: Schema name given to the ``MetaData`` when no cache is used.
            cache_path: Optional path of the pickled metadata cache. ``None``
                disables caching.
        """
        self._schema = schema
        self._cache_path = cache_path

        self._engine = engine
        self._session = sqlalchemy.orm.sessionmaker(self._engine)
        self._cached = False
        self.odm2_engine: ODM2Engine = ODM2Engine(self._session)

        self._model_base = self._prepare_model_base()
        self.models = Models(self._model_base)
        if not self._cached:
            self._prepare_automap_models()

    def _prepare_model_base(self):
        """Return the base class for the models and record whether a cache was used.

        Sets ``self._cached`` to ``True`` only when metadata was loaded from
        ``cache_path``.
        """
        # Without a cache path there is nothing to open; ``open(None)`` would
        # raise TypeError rather than FileNotFoundError.
        if self._cache_path:
            try:
                with open(self._cache_path, 'rb') as file:
                    # SECURITY: unpickling can execute arbitrary code, so the
                    # cache file must come from a trusted location.
                    metadata = pickle.load(file=file)
            except FileNotFoundError:
                # No cache written yet: fall through to the automap base.
                pass
            else:
                self._cached = True
                # NOTE(review): the ``bind`` argument looks deprecated in newer
                # SQLAlchemy releases - confirm against the pinned version.
                return declarative_base(cls=Base, bind=self._engine, metadata=metadata)

        metadata = sqlalchemy.MetaData(schema=self._schema)
        self._cached = False
        return automap_base(cls=Base, metadata=metadata)

    def _prepare_automap_models(self):
        """Prepare the automap base against the engine and try to cache its metadata.

        Caching is best-effort: if the cache file cannot be written, a
        ``RuntimeWarning`` is issued instead of raising.
        """
        self._model_base.prepare(self._engine)
        if not self._cache_path:
            return
        try:
            with open(self._cache_path, 'wb') as file:
                pickle.dump(self._model_base.metadata, file)
        except OSError:
            # Covers a missing directory (FileNotFoundError) as well as other
            # file system failures such as missing write permission.
            warnings.warn('Unable to cache models which may lead to degraded performance.', RuntimeWarning)
0 commit comments