Skip to content

Commit e7c42bd

Browse files
committed
Combine dataclass fn exec calls
It was mentioned that a low-hanging fruit to speed up startup time for dataclasses is to merge exec calls into 1. This version speeds up definition 1.26 (old time / new time) using this benchmark: https://gist.github.com/ssweber/34222fd2e770a72147903ab2706f98d7
1 parent d73c12b commit e7c42bd

File tree

1 file changed

+55
-54
lines changed

1 file changed

+55
-54
lines changed

Lib/dataclasses.py

Lines changed: 55 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -446,33 +446,39 @@ def _tuple_str(obj_name, fields):
446446
return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
447447

448448

449-
def _create_fn(name, args, body, *, globals=None, locals=None,
450-
return_type=MISSING):
449+
def _create_fn_def(name, args, body, *, locals=None, return_type=MISSING):
451450
# Note that we may mutate locals. Callers beware!
452451
# The only callers are internal to this module, so no
453452
# worries about external callers.
454453
if locals is None:
455454
locals = {}
456455
return_annotation = ''
457456
if return_type is not MISSING:
458-
locals['__dataclass_return_type__'] = return_type
459-
return_annotation = '->__dataclass_return_type__'
457+
return_name = name.replace("__", "")
458+
locals[f'__dataclass_{return_name}_return_type__'] = return_type
459+
return_annotation = f'->__dataclass_{return_name}_return_type__'
460460
args = ','.join(args)
461461
body = '\n'.join(f' {b}' for b in body)
462462

463463
# Compute the text of the entire function.
464-
txt = f' def {name}({args}){return_annotation}:\n{body}'
464+
txt = f'def {name}({args}){return_annotation}:\n{body}'
465465

466+
return (name, txt, locals)
467+
468+
def _exec_fn_defs(fn_defs, globals=None):
466469
# Free variables in exec are resolved in the global namespace.
467470
# The global namespace we have is user-provided, so we can't modify it for
468471
# our purposes. So we put the things we need into locals and introduce a
469472
# scope to allow the function we're creating to close over them.
470-
local_vars = ', '.join(locals.keys())
471-
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
473+
locals_dict = {k: v for _, _, locals_ in fn_defs
474+
for k, v in locals_.items()}
475+
local_vars = ', '.join(locals_dict.keys())
476+
fn_names = ", ".join(name for name, _, _ in fn_defs)
477+
txt = "\n".join(f" {txt}" for _, txt, _ in fn_defs)
478+
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {fn_names}"
472479
ns = {}
473480
exec(txt, globals, ns)
474-
return ns['__create_fn__'](**locals)
475-
481+
return ns['__create_fn__'](**locals_dict)
476482

477483
def _field_assign(frozen, name, value, self_name):
478484
# If we're a frozen class, then assign to our fields in __init__
@@ -566,7 +572,7 @@ def _init_param(f):
566572

567573

568574
def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
569-
self_name, globals, slots):
575+
self_name, slots):
570576
# fields contains both real fields and InitVar pseudo-fields.
571577

572578
# Make sure we don't have fields without defaults following fields
@@ -616,68 +622,61 @@ def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
616622
# (instead of just concatenting the lists together).
617623
_init_params += ['*']
618624
_init_params += [_init_param(f) for f in kw_only_fields]
619-
return _create_fn('__init__',
625+
return _create_fn_def('__init__',
620626
[self_name] + _init_params,
621627
body_lines,
622628
locals=locals,
623-
globals=globals,
624629
return_type=None)
625630

626631

627-
def _repr_fn(fields, globals):
628-
fn = _create_fn('__repr__',
632+
def _repr_fn(fields):
633+
return _create_fn_def('__repr__',
629634
('self',),
630635
['return f"{self.__class__.__qualname__}(' +
631636
', '.join([f"{f.name}={{self.{f.name}!r}}"
632637
for f in fields]) +
633-
')"'],
634-
globals=globals)
635-
return _recursive_repr(fn)
638+
')"'],)
636639

637640

638-
def _frozen_get_del_attr(cls, fields, globals):
641+
def _frozen_get_del_attr(cls, fields):
639642
locals = {'cls': cls,
640643
'FrozenInstanceError': FrozenInstanceError}
641644
condition = 'type(self) is cls'
642645
if fields:
643646
condition += ' or name in {' + ', '.join(repr(f.name) for f in fields) + '}'
644-
return (_create_fn('__setattr__',
647+
return (_create_fn_def('__setattr__',
645648
('self', 'name', 'value'),
646649
(f'if {condition}:',
647650
' raise FrozenInstanceError(f"cannot assign to field {name!r}")',
648651
f'super(cls, self).__setattr__(name, value)'),
649-
locals=locals,
650-
globals=globals),
651-
_create_fn('__delattr__',
652+
locals=locals),
653+
_create_fn_def('__delattr__',
652654
('self', 'name'),
653655
(f'if {condition}:',
654656
' raise FrozenInstanceError(f"cannot delete field {name!r}")',
655657
f'super(cls, self).__delattr__(name)'),
656-
locals=locals,
657-
globals=globals),
658+
locals=locals),
658659
)
659660

660661

661-
def _cmp_fn(name, op, self_tuple, other_tuple, globals):
662+
def _cmp_fn(name, op, self_tuple, other_tuple):
662663
# Create a comparison function. If the fields in the object are
663664
# named 'x' and 'y', then self_tuple is the string
664665
# '(self.x,self.y)' and other_tuple is the string
665666
# '(other.x,other.y)'.
666667

667-
return _create_fn(name,
668+
return _create_fn_def(name,
668669
('self', 'other'),
669670
[ 'if other.__class__ is self.__class__:',
670671
f' return {self_tuple}{op}{other_tuple}',
671-
'return NotImplemented'],
672-
globals=globals)
672+
'return NotImplemented'],)
673673

674674

675-
def _hash_fn(fields, globals):
675+
def _hash_fn(fields):
676676
self_tuple = _tuple_str('self', fields)
677-
return _create_fn('__hash__',
677+
return _create_fn_def('__hash__',
678678
('self',),
679-
[f'return hash({self_tuple})'],
680-
globals=globals)
679+
[f'return hash({self_tuple})'],)
681680

682681

683682
def _is_classvar(a_type, typing):
@@ -925,6 +924,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
925924
# derived class fields overwrite base class fields, but the order
926925
# is defined by the base class, which is found first.
927926
fields = {}
927+
fn_defs = []
928928

929929
if cls.__module__ in sys.modules:
930930
globals = sys.modules[cls.__module__].__dict__
@@ -1059,8 +1059,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
10591059
# Does this class have a post-init function?
10601060
has_post_init = hasattr(cls, _POST_INIT_NAME)
10611061

1062-
_set_new_attribute(cls, '__init__',
1063-
_init_fn(all_init_fields,
1062+
fn_defs.append(_init_fn(all_init_fields,
10641063
std_init_fields,
10651064
kw_only_init_fields,
10661065
frozen,
@@ -1070,7 +1069,6 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
10701069
# if possible.
10711070
'__dataclass_self__' if 'self' in fields
10721071
else 'self',
1073-
globals,
10741072
slots,
10751073
))
10761074
_set_new_attribute(cls, '__replace__', _replace)
@@ -1081,7 +1079,7 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
10811079

10821080
if repr:
10831081
flds = [f for f in field_list if f.repr]
1084-
_set_new_attribute(cls, '__repr__', _repr_fn(flds, globals))
1082+
fn_defs.append(_repr_fn(flds))
10851083

10861084
if eq:
10871085
# Create __eq__ method. There's no need for a __ne__ method,
@@ -1092,33 +1090,36 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
10921090
body = [f'if other.__class__ is self.__class__:',
10931091
f' return {field_comparisons}',
10941092
f'return NotImplemented']
1095-
func = _create_fn('__eq__',
1096-
('self', 'other'),
1097-
body,
1098-
globals=globals)
1099-
_set_new_attribute(cls, '__eq__', func)
1093+
fn_defs.append(_create_fn_def('__eq__',('self', 'other'), body,))
11001094

11011095
if order:
11021096
# Create and set the ordering methods.
11031097
flds = [f for f in field_list if f.compare]
11041098
self_tuple = _tuple_str('self', flds)
11051099
other_tuple = _tuple_str('other', flds)
1106-
for name, op in [('__lt__', '<'),
1107-
('__le__', '<='),
1108-
('__gt__', '>'),
1109-
('__ge__', '>='),
1110-
]:
1111-
if _set_new_attribute(cls, name,
1112-
_cmp_fn(name, op, self_tuple, other_tuple,
1113-
globals=globals)):
1100+
order_flds = {'__lt__' : '<',
1101+
'__le__' : '<=',
1102+
'__gt__' : '>',
1103+
'__ge__' : '>=',
1104+
}
1105+
for name, op in order_flds.items():
1106+
fn_defs.append(_cmp_fn(name, op, self_tuple, other_tuple))
1107+
1108+
if frozen:
1109+
fn_defs.extend(_frozen_get_del_attr(cls, field_list))
1110+
1111+
functions_objects = _exec_fn_defs(fn_defs, globals=globals)
1112+
for fn in functions_objects:
1113+
name = fn.__name__
1114+
if name == '__repr__':
1115+
fn = _recursive_repr(fn)
1116+
if _set_new_attribute(cls, name, fn):
1117+
if order and name in order_flds:
11141118
raise TypeError(f'Cannot overwrite attribute {name} '
11151119
f'in class {cls.__name__}. Consider using '
11161120
'functools.total_ordering')
1117-
1118-
if frozen:
1119-
for fn in _frozen_get_del_attr(cls, field_list, globals):
1120-
if _set_new_attribute(cls, fn.__name__, fn):
1121-
raise TypeError(f'Cannot overwrite attribute {fn.__name__} '
1121+
elif frozen:
1122+
raise TypeError(f'Cannot overwrite attribute {name} '
11221123
f'in class {cls.__name__}')
11231124

11241125
# Decide if/how we're going to create a hash function.

0 commit comments

Comments
 (0)