-
Notifications
You must be signed in to change notification settings - Fork 92
print selene_sdk version, add config and model file to output, add ra… #191
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,12 +9,14 @@ | |
from time import strftime | ||
import types | ||
import random | ||
import shutil, yaml | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from . import _is_lua_trained_model | ||
from . import instantiate | ||
from . import load_path | ||
|
||
|
||
def class_instantiate(classobj): | ||
|
@@ -111,6 +113,7 @@ def initialize_model(model_configs, train=True, lr=None): | |
|
||
module = None | ||
if os.path.isdir(import_model_from): | ||
import_model_from = import_model_from.rstrip(os.sep) | ||
module = module_from_dir(import_model_from) | ||
else: | ||
module = module_from_file(import_model_from) | ||
|
@@ -251,7 +254,7 @@ def execute(operations, configs, output_dir): | |
analyze_seqs.get_predictions(**predict_info) | ||
|
||
|
||
def parse_configs_and_run(configs, | ||
def parse_configs_and_run(configs_file, | ||
create_subdirectory=True, | ||
lr=None): | ||
""" | ||
|
@@ -260,9 +263,9 @@ def parse_configs_and_run(configs, | |
|
||
Parameters | ||
---------- | ||
configs : dict | ||
The dictionary of nested configuration parameters. Will look | ||
for the following top-level parameters: | ||
configs_file : str | ||
The configuration YAML file of nested configuration parameters. | ||
Will look for the following top-level parameters: | ||
|
||
* `ops`: A list of 1 or more of the values \ | ||
{"train", "evaluate", "analyze"}. The operations specified\ | ||
|
@@ -305,8 +308,18 @@ def parse_configs_and_run(configs, | |
to the dirs specified in each operation's configuration. | ||
|
||
""" | ||
if isinstance(configs_file, str): | ||
configs = load_path(configs_file, instantiate=False) | ||
else: | ||
ygliu2016 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
configs = configs_file | ||
operations = configs["ops"] | ||
|
||
#print selene_sdk version | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. remove commented out print statement |
||
if "selene_sdk_version" not in configs: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe simplify this to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not sure about name "selene_sdk_version" vs "version". May want to keep the longer version at this time. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK I thought about it more, and I think it's better to have this be automatically populated? Let's just remove lines 322 and 323 outright. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please adjust based on comments above, I think it makes sense to just auto-populate since almost no one will try to specify it manually in their config |
||
from selene_sdk import version | ||
configs["selene_sdk_version"] = version.__version__ | ||
print("Selene_sdk Version = {}".format(version.__version__)) | ||
ygliu2016 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
||
if "train" in operations and "lr" not in configs and lr != None: | ||
configs["lr"] = float(lr) | ||
elif "train" in operations and "lr" in configs and lr != None: | ||
|
@@ -331,8 +344,9 @@ def parse_configs_and_run(configs, | |
if "create_subdirectory" in configs: | ||
create_subdirectory = configs["create_subdirectory"] | ||
if create_subdirectory: | ||
rand_str = str(random.random())[2:] | ||
current_run_output_dir = os.path.join( | ||
current_run_output_dir, strftime("%Y-%m-%d-%H-%M-%S")) | ||
current_run_output_dir, '{}-{}'.format(strftime("%Y-%m-%d-%H-%M-%S"), rand_str)) | ||
os.makedirs(current_run_output_dir) | ||
print("Outputs and logs saved to {0}".format( | ||
current_run_output_dir)) | ||
|
@@ -343,9 +357,29 @@ def parse_configs_and_run(configs, | |
np.random.seed(seed) | ||
torch.manual_seed(seed) | ||
torch.cuda.manual_seed_all(seed) | ||
#torch.backends.cudnn.deterministic = True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why are these commented out? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is your code, which is not included in the main repository of selene. I am not sure if the code is necessary and thus copied here but commented them. It seems to me that with a seed, there is already deterministic in it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks for the clarification. Please un-comment this as we need it for ensuring all aspects are deterministic. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same with the commented out line in 365, if that is my code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you remove the commented out lines in 364 and 365? |
||
#torch.backends.cudnn.benchmark = False | ||
print("Setting random seed = {0}".format(seed)) | ||
else: | ||
print("Warning: no random seed specified in config file. " | ||
"Using a random seed ensures results are reproducible.") | ||
|
||
if current_run_output_dir: | ||
# write configs to output directory | ||
if isinstance(configs_file, str): | ||
config_out = '{0}_lr{1}.yml'.format( | ||
os.path.basename(configs_file)[:-4], configs['lr']) | ||
shutil.copyfile(configs_file, | ||
os.path.join(current_run_output_dir, config_out)) | ||
else: | ||
with open('{}/{}'.format(current_run_output_dir,'configs.yaml'), 'w') as f: | ||
yaml.dump(configs, f, default_flow_style=None) | ||
# copy model file or directory to output | ||
model_input = configs['model']['path'] | ||
if os.path.isdir(model_input): # copy the directory | ||
shutil.copytree (model_input, os.path.join(current_run_output_dir, os.path.basename(import_model_from)), dirs_exist_ok=True) | ||
ygliu2016 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
else: | ||
shutil.copy (model_input, current_run_output_dir) | ||
ygliu2016 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
|
||
execute(operations, configs, current_run_output_dir) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -47,6 +47,6 @@ train_model: !obj:selene_sdk.TrainModel { | |
} | ||
random_seed: 1447 | ||
output_dir: ./training_outputs | ||
create_subdirectory: False | ||
create_subdirectory: True | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I assume this was changed because of testing - can this be changed back since it's not relevant to the PR? |
||
load_test_set: False | ||
... |
Uh oh!
There was an error while loading. Please reload this page.