Source code for brainscore_vision.utils

"""
Provide generic helper classes.
"""

import copy


[docs]def fullname(obj): """ Resolve the full module-qualified name of an object. Typically used for logger naming. """ return obj.__module__ + "." + obj.__class__.__name__
[docs]def map_fields(obj, func): for field_name, field_value in vars(obj).items(): field_value = func(field_value) setattr(obj, field_name, field_value)
[docs]def combine_fields(objs, func): if len(objs) == 0: return objs fields = list(vars(objs[0]).keys()) field_values = {field_name: [] for field_name in fields} for obj in objs: for field_name in fields: field_value = getattr(obj, field_name) field_values[field_name].append(field_value) field_values = dict(map( lambda field_name_values: (field_name_values[0], func(field_name_values[1])), field_values.items())) ctr = objs[0].__class__ return ctr(**field_values)
[docs]def recursive_dict_merge(dict1, dict2): """ Merges dictionaries (of dictionaries). Preference is given to the second dict, i.e. if a key occurs in both dicts, the value from `dict2` is used. """ result = copy.deepcopy(dict1) for key in dict2: if key in dict1 and isinstance(dict1[key], dict) and isinstance(dict2[key], dict): result[key] = recursive_dict_merge(dict1[key], dict2[key]) else: result[key] = dict2[key] return result
[docs]class LazyLoad:
[docs] def __init__(self, load_fnc): self.load_fnc = load_fnc self.content = None
def __getattr__(self, name): self._ensure_loaded() return getattr(self.content, name) def __setattr__(self, key, value): if key not in ['content', 'load_fnc']: self._ensure_loaded() return setattr(self.content, key, value) return super(LazyLoad, self).__setattr__(key, value) def __getitem__(self, item): self._ensure_loaded() return self.content.__getitem__(item) def __setitem__(self, key, value): self._ensure_loaded() return self.content.__setitem__(key, value) def _ensure_loaded(self): if self.content is None: self.content = self.load_fnc()
[docs] def reload(self): self.content = self.load_fnc()
[docs] def __call__(self, *args, **kwargs): self._ensure_loaded() return self.content(*args, **kwargs)
def __len__(self): self._ensure_loaded() return len(self.content) @property def __class__(self): self._ensure_loaded() return self.content.__class__