from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
class OrcaBlock(ABC):
    """Abstract base for one block of an ORCA input configuration.

    Concrete subclasses report which ``BlockType`` they represent;
    ``OrcaConfig.add_block`` uses that value as the dictionary key.
    """

    @abstractmethod
    def block_type(self) -> BlockType:
        """Return the ``BlockType`` that identifies this block."""
        ...
@dataclass
@dataclass
[docs]
class Atom:
[docs]
symbol: str = field(default=None)
[docs]
embedding_potential: bool = False
[docs]
is_frozen: bool = False # Not applied to anything but cartesian
[docs]
isotope: float | None = None
[docs]
nuclear_charge: float | None = None
[docs]
fragment_number: int | None = None
[docs]
point_charge: float | None = None
[docs]
bond_atom: int | None = None # Index of bonded atom (for internal coordinates)
[docs]
bond_length: float | None = None # Bond length (for internal coordinates)
[docs]
angle_atom: int | None = None # Index of angle atom (for internal coordinates)
[docs]
angle: float | None = None # Bond angle (for internal coordinates)
[docs]
dihedral_atom: int | None = None # Index of dihedral atom (for internal coordinates)
[docs]
dihedral: float | None = None # Dihedral angle (for internal coordinates)
[docs]
is_frozen_x: bool = False # Cartesian only
[docs]
is_frozen_y: bool = False # Cartesian only
[docs]
is_frozen_z: bool = False # Cartesian only
[docs]
def __post_init__(self):
if self.point_charge is not None:
self.symbol = "Q"
elif self.symbol is None:
msg = "Atom symbol is required unless it's a point charge."
raise ValueError(msg)
@dataclass
class Coords:
    """Ordered list of ``Atom`` entries; starts empty when not supplied."""

    atoms: list[Atom] = field(default_factory=list)
@dataclass
class GeomBlock(OrcaBlock):
    """Geometry block (``BlockType.GEOM``) holding ``GeomScan`` definitions.

    Decorated with ``@dataclass`` exactly once: applying it a second time
    drops the class attributes of ``default_factory`` fields, which turns
    ``bonds``/``dihedrals``/``angles`` into required constructor arguments.
    """

    bonds: list[GeomScan] = field(default_factory=list)
    dihedrals: list[GeomScan] = field(default_factory=list)
    angles: list[GeomScan] = field(default_factory=list)

    def block_type(self) -> BlockType:
        """Return ``BlockType.GEOM``."""
        return BlockType.GEOM
@dataclass
class LBFGSSettings:
    """L-BFGS optimizer flags stored on ``NebBlock.lbfgs_settings``.

    Values are stored as given; no validation is performed.
    """

    reparam_on_restart: bool = False
    precondition: bool = True
    restart_on_maxmove: bool = True
@dataclass
class ReparamSettings:
    """Path reparametrization options.

    Attributes:
        interp: interpolation scheme, ``"linear"`` or ``"cubic"``
            (case-insensitive).

    Raises:
        ValueError: if ``interp`` is not a recognised scheme.
    """

    # Must be declared because __post_init__ validates it; without the field
    # every instantiation raised AttributeError. Default "linear" is assumed
    # to mirror ORCA's default — confirm against the ORCA manual.
    interp: str = "linear"

    def __post_init__(self):
        valid_interps = {"linear", "cubic"}
        if self.interp.lower() not in valid_interps:
            msg = f"Interp must be one of {valid_interps}, got '{self.interp}'"
            raise ValueError(msg)
@dataclass
class ConvTolSettings:
    """Convergence thresholds stored on ``NebBlock.convtol_settings``.

    NOTE(review): the ``_ci`` suffix presumably refers to the climbing image;
    confirm the exact meaning and units against the ORCA manual.
    """

    maxf_ci: float = 0.0005
    rmsf_ci: float = 0.0003
    turn_on_ci: float = 0.02
@dataclass
class OptimSettings:
    """Optimizer selection for the NEB run.

    Attributes:
        method: one of ``"LBFGS"``, ``"VPO"``, ``"FIRE"`` (case-insensitive).

    Raises:
        ValueError: if ``method`` is not a recognised optimizer.
    """

    # Declared because __post_init__ validates it (previously missing, causing
    # AttributeError). Default "LBFGS" assumed to match ORCA — confirm.
    method: str = "LBFGS"

    def __post_init__(self):
        valid_methods = {"LBFGS", "VPO", "FIRE"}
        if self.method.upper() not in valid_methods:
            msg = f"Method must be one of {valid_methods}, got '{self.method}'"
            raise ValueError(msg)
@dataclass
class FreeEndSettings:
    """Free-end optimization options.

    Attributes:
        opt_type: one of ``"PERP"``, ``"CONTOUR"``, ``"FULL"``
            (case-insensitive).

    Raises:
        ValueError: if ``opt_type`` is not recognised.
    """

    # Declared because __post_init__ validates it (previously missing, causing
    # AttributeError). Default "PERP" assumed to match ORCA — confirm.
    opt_type: str = "PERP"

    def __post_init__(self):
        valid_opt_types = {"PERP", "CONTOUR", "FULL"}
        if self.opt_type.upper() not in valid_opt_types:
            msg = f"opt_type must be one of {valid_opt_types}, got '{self.opt_type}'"
            raise ValueError(msg)
@dataclass
class ZoomSettings:
    """Zoom options for the NEB run.

    Attributes:
        tol_turn_on: threshold value stored as given.
        interpolation: ``"linear"`` or ``"cubic"`` (case-insensitive).
        printfulltrj: flag stored as given.

    Raises:
        ValueError: if ``interpolation`` is not recognised.
    """

    tol_turn_on: float = 0.0
    interpolation: str = "linear"
    printfulltrj: bool = True

    def __post_init__(self):
        valid_interpolations = {"linear", "cubic"}
        if self.interpolation.lower() not in valid_interpolations:
            # Both halves are f-strings so the offending value is interpolated.
            msg = (
                f"interpolation must be one of {valid_interpolations},"
                f" got '{self.interpolation}'"
            )
            raise ValueError(msg)
@dataclass
class SpringSettings:
    """Spring options for the NEB path.

    Attributes:
        spring_kind: ``"image"``, ``"dof"`` or ``"ideal"`` (case-insensitive).
        energy_weighted: flag stored as given.
        perpspring: ``"no"``, ``"cos"``, ``"tan"``, ``"cosTan"`` or ``"DNEB"``
            (case-insensitive).

    Raises:
        ValueError: if ``spring_kind`` or ``perpspring`` is not recognised.
    """

    spring_kind: str = "image"
    energy_weighted: bool = True
    # Declared because __post_init__ validates it (previously missing, causing
    # AttributeError). Appended last so positional construction is unchanged.
    # Default "no" assumed to match ORCA — confirm.
    perpspring: str = "no"

    def __post_init__(self):
        valid_springkinds = {"image", "dof", "ideal"}
        valid_perpsprings = {"no", "cos", "tan", "cosTan", "DNEB"}
        if self.spring_kind.lower() not in valid_springkinds:
            msg = (
                f"spring_kind must be one of {valid_springkinds},"
                f" got '{self.spring_kind}'"
            )
            raise ValueError(msg)
        # Lowercase both sides: "cosTan" and "DNEB" contain capitals, so
        # lowercasing only the input would always reject them.
        if self.perpspring.lower() not in {p.lower() for p in valid_perpsprings}:
            msg = (
                f"perpspring must be one of {valid_perpsprings},"
                f" got '{self.perpspring}'"
            )
            raise ValueError(msg)
@dataclass
[docs]
class RestartSettings:
[docs]
gbw_basename: str = None
[docs]
def __post_init__(self):
if self.gbw_basename and self.allxyz:
msg = "Only one of gbw_basename or allxyz should be provided."
raise ValueError(msg)
@dataclass
[docs]
class TSGuessSettings:
[docs]
def __post_init__(self):
if self.xyz_struct and self.pdb_struct:
msg = "Only one of xyz_struct or pdb_struct should be provided."
raise ValueError(msg)
@dataclass
class FixCenterSettings:
    """Fix-center options stored on ``NebBlock.fix_center_settings``.

    The flag is stored as given; no validation is performed.
    """

    remove_extern_force: bool = True
@dataclass
class NebBlock(OrcaBlock):
    """NEB block configuration (``BlockType.NEB``).

    Scalar options are validated case-insensitively in ``__post_init__``:
    ``convtype`` in {all, cionly}, ``quatern`` in {no, startonly, always},
    ``tangent`` in {improved, original}, and ``interpolation`` in
    {IDPP, LINEAR, XTB1TS, XTB1, XTB2TS, XTB2}. Each nested ``*_settings``
    object is built with its own defaults when not supplied.

    Raises:
        ValueError: if any validated option has an unrecognised value; the
            message includes the offending value.
    """

    convtype: str = "CIONLY"
    quatern: str = "ALWAYS"
    climbingimage: bool = True
    check_scf_conv: bool = True
    npts_interpol: int = 10
    interpolation: str = "IDPP"
    tangent: str = "IMPROVED"
    lbfgs_settings: LBFGSSettings = field(default_factory=LBFGSSettings)
    fire_settings: FIRESettings = field(default_factory=FIRESettings)
    reparam_settings: ReparamSettings = field(default_factory=ReparamSettings)
    idpp_settings: IDPPSettings = field(default_factory=IDPPSettings)
    zoom_settings: ZoomSettings = field(default_factory=ZoomSettings)
    optim_settings: OptimSettings = field(default_factory=OptimSettings)
    convtol_settings: ConvTolSettings = field(default_factory=ConvTolSettings)
    free_end_settings: FreeEndSettings = field(default_factory=FreeEndSettings)
    spring_settings: SpringSettings = field(default_factory=SpringSettings)
    restart_settings: RestartSettings = field(default_factory=RestartSettings)
    tsguess_settings: TSGuessSettings = field(default_factory=TSGuessSettings)
    fix_center_settings: FixCenterSettings = field(default_factory=FixCenterSettings)

    def __post_init__(self):
        valid_convtypes = {"all", "cionly"}
        valid_quaterns = {"no", "startonly", "always"}
        valid_tangents = {"improved", "original"}
        valid_interpolations = {"IDPP", "LINEAR", "XTB1TS", "XTB1", "XTB2TS", "XTB2"}
        # Every message half that references a value is an f-string with
        # single braces, so the actual offending value appears in the error.
        if self.convtype.lower() not in valid_convtypes:
            msg = (
                f"Convergence type must be one of {valid_convtypes},"
                f" got '{self.convtype}'"
            )
            raise ValueError(msg)
        if self.quatern.lower() not in valid_quaterns:
            msg = f"quatern must be one of {valid_quaterns}, got '{self.quatern}'"
            raise ValueError(msg)
        if self.tangent.lower() not in valid_tangents:
            msg = f"tangent must be one of {valid_tangents}, got '{self.tangent}'"
            raise ValueError(msg)
        if self.interpolation.upper() not in valid_interpolations:
            msg = (
                f"interpolation must be one of {valid_interpolations},"
                f" got '{self.interpolation}'"
            )
            raise ValueError(msg)

    def block_type(self) -> BlockType:
        """Return ``BlockType.NEB``."""
        return BlockType.NEB
@dataclass
class OrcaConfig:
    """ORCA input blocks keyed by their ``BlockType``.

    Adding a block whose type is already present replaces the earlier one.
    """

    blocks: dict[BlockType, OrcaBlock] = field(default_factory=dict)

    def add_block(self, block: OrcaBlock):
        """Store *block* under the key returned by ``block.block_type()``."""
        key = block.block_type()
        self.blocks[key] = block
# @dataclass
# class OrcaConfig:
# coords: Coords
# kwlines: List[KWLine] = None
# orca_geom: Optional[OrcaGeom] = None