# Source code for aiida_atoms.tracker

"""
Track changes of an atom
"""
from functools import wraps
from typing import Union
import warnings

from ase import Atoms
from ase.build import make_supercell
from ase.build import sort as ase_sort
import numpy as np
from packaging import version

from aiida import __version__ as AIIDA_VERSION
from aiida import orm
from aiida.engine import calcfunction


def dummy_function(*args, **kwargs):
    """
    Placeholder that accepts arbitrary ``*args`` and ``**kwargs`` and does nothing.

    Needed to trigger the dynamic namespace in aiida-core >= 2.3.0.
    """
    # The arguments are intentionally discarded.
    del args, kwargs
def wraps_ase_out_of_place(func):
    """
    Wrap an ASE out-of-place operation so that it operates on an ``AtomsTracker``.

    The wrapped ``func`` is called as ``func(atoms, *args, **kwargs)`` where
    ``atoms`` is obtained from the tracker's node, and its return value is
    stored in a new ``orm.StructureData``. When the tracker has
    ``track_provenance`` set, the new node is produced through a
    ``calcfunction`` whose inputs are the tracker's node and the AiiDA
    representations of the call arguments.

    :param func: callable taking the structure as its first argument and
        returning the new structure.
    :returns: a function taking the tracker as its first argument and returning
        a new ``AtomsTracker`` that keeps the ``track_provenance`` setting of
        the tracker it was derived from.
    """

    @wraps(func)
    def inner(tracker, *args, **kwargs):
        """Inner function wrapped"""
        atoms = tracker.node.get_ase()

        # Keyword arguments keep their names; positional arguments are stored
        # under the generated keys ``arg_00``, ``arg_01``, ...
        aiida_kwargs = {key: to_aiida_rep(value) for key, value in kwargs.items()}
        for i, arg in enumerate(args):
            aiida_kwargs[f"arg_{i:02d}"] = to_aiida_rep(arg)

        # The actual operation is performed here; ``_transform`` below only
        # creates a dummy connection between the input and the output.
        new_atoms = func(atoms, *args, **kwargs)

        @wraps(func)
        def _transform(node, **dummy_args):  # pylint:disable=unused-argument
            return orm.StructureData(ase=new_atoms)

        if tracker.track_provenance:
            # Only build the calcfunction when it is actually going to be used.
            if version.parse(AIIDA_VERSION) >= version.parse("2.3.0"):
                _transform.__wrapped__ = dummy_function
            node = calcfunction(_transform)(tracker.node, **aiida_kwargs)
        else:
            node = _transform(tracker.node, **aiida_kwargs)

        # Propagate the tracking setting: without ``track=`` the new tracker
        # would fall back to the default and start tracking provenance even if
        # the original tracker was created with ``track=False``.
        return AtomsTracker(obj=node, atoms=new_atoms, track=tracker.track_provenance)

    return inner
# Short alias for ``wraps_ase_out_of_place``
wop = wraps_ase_out_of_place
def wraps_ase_inplace(func):
    """
    Wrap an ASE in-place operation so that it operates on an ``AtomsTracker``.

    The wrapped ``func`` is called as ``func(tracker.atoms, *args, **kwargs)``
    and mutates ``tracker.atoms``. A new ``orm.StructureData`` is then created
    from the mutated structure and assigned to ``tracker.node``. When the
    tracker has ``track_provenance`` set, this happens inside a
    ``calcfunction`` whose inputs are the previous node and the AiiDA
    representations of the call arguments.

    :param func: callable taking the structure as its first argument and
        modifying it in place.
    :returns: a function taking the tracker as its first argument. It returns
        the value returned by ``func``, except when that value is the mutated
        structure itself, in which case the tracker is returned.
    """

    @wraps(func)
    def inner(tracker, *args, **kwargs):
        """Inner function wrapped"""
        atoms = tracker.atoms

        # Keyword arguments keep their names; positional arguments are stored
        # under the generated keys ``arg_00``, ``arg_01``, ...
        aiida_kwargs = {key: to_aiida_rep(value) for key, value in kwargs.items()}
        for i, arg in enumerate(args):
            aiida_kwargs[f"arg_{i:02d}"] = to_aiida_rep(arg)

        # Collects the return value of ``func`` from inside ``_transform``.
        # NOTE(review): this is only filled when ``_transform`` actually runs;
        # presumably the calcfunction always executes it in-process - confirm
        # the behaviour when AiiDA caching is enabled.
        retobj = []

        @wraps(func)
        def _transform(node, **dummy_args):  # pylint:disable=unused-argument
            # func is an inplace operation
            retobj.append(func(atoms, *args, **kwargs))
            return orm.StructureData(ase=atoms)

        if tracker.track_provenance:
            # Only build the calcfunction when we are indeed tracking the provenance
            if version.parse(AIIDA_VERSION) >= version.parse("2.3.0"):
                _transform.__wrapped__ = dummy_function
            node = calcfunction(_transform)(tracker.node, **aiida_kwargs)
        else:
            node = _transform(tracker.node, **aiida_kwargs)

        # Update the current node
        tracker.node = node

        result = retobj[0]
        # In-place operators such as ``__imul__`` return the mutated object
        # itself. Returning it unchanged would make ``tracker *= 2`` rebind the
        # name to the bare ``Atoms``; hand back the tracker instead.
        if result is atoms:
            return tracker
        return result

    return inner
def to_aiida_rep(pobj):
    """
    Convert to AiiDA representation and serialization.

    The return object is not guaranteed to fully deserialize back to the input.
    A string representation is used as the fallback.

    Supported inputs: ``dict``, ``list``, ``tuple`` (stored as a list),
    ``Atoms``, ``bool``, ``float``, ``int``, ``str``, NumPy boolean / integer /
    floating scalars and ``numpy.ndarray``.

    :param pobj: the Python object to convert.
    :returns: an ``orm`` data node holding ``pobj``. For unsupported types a
        warning is emitted and an ``orm.Str`` is returned.
    """
    if isinstance(pobj, dict):
        return orm.Dict(dict=pobj)
    if isinstance(pobj, list):
        return orm.List(list=pobj)
    if isinstance(pobj, tuple):
        return orm.List(list=list(pobj))
    if isinstance(pobj, Atoms):
        return orm.StructureData(ase=pobj)
    # ``bool`` is a subclass of ``int`` and therefore must be tested before it,
    # otherwise ``True`` / ``False`` would be stored as ``Int``.
    if isinstance(pobj, (bool, np.bool_)):
        return orm.Bool(bool(pobj))
    # NumPy scalars are not all subclasses of the builtin numeric types, so
    # they are matched explicitly and converted to plain Python numbers.
    if isinstance(pobj, (float, np.floating)):
        return orm.Float(float(pobj))
    if isinstance(pobj, (int, np.integer)):
        return orm.Int(int(pobj))
    if isinstance(pobj, str):
        return orm.Str(pobj)
    if isinstance(pobj, np.ndarray):
        data = orm.ArrayData()
        data.set_array("array", pobj)
        return data
    # ``stacklevel=2`` attributes the warning to the caller of this function.
    warnings.warn(
        f"Cannot serialise {pobj} - falling back to string representation.",
        stacklevel=2,
    )
    return orm.Str(pobj)
class AtomsTracker:  # pylint: disable=too-few-public-methods
    """Pair an ASE ``Atoms`` object with the AiiDA node that represents it."""

    def __init__(
        self,
        obj,
        atoms: Union[Atoms, None] = None,
        track=True,
    ):
        """
        Build a tracker from an ``Atoms``, another ``AtomsTracker`` or a node.

        :param obj: an ``Atoms`` (a new ``orm.StructureData`` is created from
            it), an ``AtomsTracker`` (its node is reused), or otherwise an
            object that is used directly as the node.
        :param atoms: structure to pair with ``obj`` when ``obj`` is used
            directly as the node; when ``None`` it is obtained from
            ``get_ase()`` of the node. Not used in the other two cases.
        :param track: stored as ``track_provenance``.
        """
        if isinstance(obj, Atoms):
            structure = obj
            node = orm.StructureData(ase=obj)
        elif isinstance(obj, AtomsTracker):
            node = obj.node
            structure = node.get_ase()
        else:
            node = obj
            structure = atoms if atoms is not None else node.get_ase()

        self.node = node
        self.atoms = structure
        self.track_provenance = track

    def __repr__(self) -> str:
        """Return a representation showing both the atoms and the node."""
        return f"AtomsTracker({self.atoms!r}, {self.node!r})"

    # Out-of-place operations provided as functions by ``ase.build``
    sort = wraps_ase_out_of_place(ase_sort)
    make_supercell = wraps_ase_out_of_place(make_supercell)

    @property
    def label(self):
        """Label of the underlying node."""
        return self.node.label

    @label.setter
    def label(self, value):
        """Set the label of the underlying node."""
        self.node.label = value

    @property
    def description(self):
        """Description of the underlying node."""
        return self.node.description

    @description.setter
    def description(self, value):
        """Set the description of the underlying node."""
        self.node.description = value

    @property
    def id(self):  # pylint: disable=invalid-name
        """ID of the underlying node"""
        return self.node.id

    @property
    def uuid(self):
        """UUID of the underlying node"""
        return self.node.uuid

    @property
    def base(self):
        """The `base` accessor for the underlying node."""
        return self.node.base

    def store_node(self, *args, **kwargs):
        """Store the underlying node, forwarding all arguments to its ``store``."""
        self.node.store(*args, **kwargs)
def _populate_methods():
    """
    Attach wrapped ``Atoms`` methods to the ``AtomsTracker`` class.

    Each listed method is looked up on ``Atoms``, wrapped with the in-place or
    the out-of-place wrapper, and set on ``AtomsTracker`` under the same name.
    """
    in_place_names = (
        "set_cell",
        "set_positions",
        "set_pbc",
        "set_atomic_numbers",
        "set_chemical_symbols",
        "set_masses",
        "pop",
        "translate",
        "center",
        "set_center_of_mass",
        "rotate",
        "euler_rotate",
        "set_dihedral",
        "rotate_dihedral",
        "set_angle",
        "rattle",
        "set_distance",
        "set_scaled_positions",
        "wrap",
        "__delitem__",
        "__imul__",
    )
    out_of_place_names = ("repeat", "__getitem__", "__mul__")

    # In-place methods are registered first, then the out-of-place ones.
    registrations = (
        (wraps_ase_inplace, in_place_names),
        (wraps_ase_out_of_place, out_of_place_names),
    )
    for wrapper, method_names in registrations:
        for method_name in method_names:
            setattr(AtomsTracker, method_name, wrapper(getattr(Atoms, method_name)))
# Attach the wrapped ASE methods to ``AtomsTracker`` when the module is executed
_populate_methods()