Skip to content

Commit f7762fe

Browse files
committed
Consistency handling None / empty string inputs to norm / act create fns
1 parent dcfdba1 commit f7762fe

File tree

3 files changed

+11
-4
lines changed

3 files changed

+11
-4
lines changed

timm/layers/create_act.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,11 +132,13 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
132132
Fetching activation layers by name with this function allows export or torch script friendly
133133
functions to be returned dynamically based on current config.
134134
"""
135-
if not name:
135+
if name is None:
136136
return None
137137
if not isinstance(name, str):
138138
# callable, module, etc
139139
return name
140+
if not name:
141+
return None
140142
if not (is_no_jit() or is_exportable() or is_scriptable()):
141143
if name in _ACT_LAYER_ME:
142144
return _ACT_LAYER_ME[name]

timm/layers/create_norm.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,7 @@ def create_norm_layer(layer_name, num_features, **kwargs):
3434

3535

3636
def get_norm_layer(norm_layer):
37-
if not norm_layer:
38-
# None or '' should return None
37+
if norm_layer is None:
3938
return None
4039
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
4140
norm_kwargs = {}
@@ -46,6 +45,8 @@ def get_norm_layer(norm_layer):
4645
norm_layer = norm_layer.func
4746

4847
if isinstance(norm_layer, str):
48+
if not norm_layer:
49+
return None
4950
layer_name = norm_layer.replace('_', '')
5051
norm_layer = _NORM_MAP[layer_name]
5152
else:

timm/layers/create_norm_act.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=Tr
5050

5151

5252
def get_norm_act_layer(norm_layer, act_layer=None):
53+
if norm_layer is None:
54+
return None
5355
assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
5456
assert act_layer is None or isinstance(act_layer, (type, str, types.FunctionType, functools.partial))
5557
norm_act_kwargs = {}
@@ -60,8 +62,10 @@ def get_norm_act_layer(norm_layer, act_layer=None):
6062
norm_layer = norm_layer.func
6163

6264
if isinstance(norm_layer, str):
65+
if not norm_layer:
66+
return None
6367
layer_name = norm_layer.replace('_', '').lower().split('-')[0]
64-
norm_act_layer = _NORM_ACT_MAP.get(layer_name, None)
68+
norm_act_layer = _NORM_ACT_MAP[layer_name]
6569
elif norm_layer in _NORM_ACT_TYPES:
6670
norm_act_layer = norm_layer
6771
elif isinstance(norm_layer, types.FunctionType):

0 commit comments

Comments
 (0)