Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions saicinpainting/training/data/aug.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from albumentations import DualIAATransform, to_tuple
import imgaug.augmenters as iaa
from albumentations import DualTransform, Affine, Perspective
from albumentations.core.utils import to_tuple

class IAAAffine2(DualIAATransform):
class IAAAffine2(DualTransform):
"""Place a regular grid of points on the input and randomly move the neighbourhood of these point around
via affine transformations.

Expand Down Expand Up @@ -39,7 +39,7 @@ def __init__(

@property
def processor(self):
return iaa.Affine(
return Affine(
self.scale,
self.translate_percent,
self.translate_px,
Expand All @@ -54,7 +54,7 @@ def get_transform_init_args_names(self):
return ("scale", "translate_percent", "translate_px", "rotate", "shear", "order", "cval", "mode")


class IAAPerspective2(DualIAATransform):
class IAAPerspective2(DualTransform):
"""Perform a random four point perspective transform of the input.

Note: This class introduce interpolation artifacts to mask if it has values other than {0;1}
Expand All @@ -78,7 +78,7 @@ def __init__(self, scale=(0.05, 0.1), keep_size=True, always_apply=False, p=0.5,

@property
def processor(self):
return iaa.PerspectiveTransform(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval)
return Perspective(self.scale, keep_size=self.keep_size, mode=self.mode, cval=self.cval)

def get_transform_init_args_names(self):
return ("scale", "keep_size")
2 changes: 1 addition & 1 deletion saicinpainting/training/trainers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def make_training_model(config):

def load_checkpoint(train_config, path, map_location='cuda', strict=True):
model: torch.nn.Module = make_training_model(train_config)
state = torch.load(path, map_location=map_location)
state = torch.load(path, map_location=map_location, weights_only=False)
model.load_state_dict(state['state_dict'], strict=strict)
model.on_load_checkpoint(state)
return model