Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 1 addition & 66 deletions molSimplify/Classes/globalvars.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,75 +338,10 @@
'y': 1, 'zr': 2, 'nb': 3, 'mo': 4, 'tc': 5, 'ru': 6, 'rh': 7, 'pd': 8, 'ag': 9, 'cd': 10,
'hf': 2, 'ta': 3, 'w': 4, 're': 5, 'os': 6, 'ir': 7, 'pt': 8, 'au': 9, 'hg': 10}

# Default spin for each d-electron count (0-10), stored as a string.
# The table is symmetric about d5: the value is min(d, 10 - d) + 1, i.e. it
# rises from '1' at d0 to '6' at d5 and falls back to '1' at d10.
# These defaults ignore the metal identity and oxidation state; see
# metal_ox_spinlist for metal/oxidation-state-specific spins.
defaultspins = {
    d_count: str(min(d_count, 10 - d_count) + 1)
    for d_count in range(11)
}

# d-electron count for each transition metal in each tabulated oxidation state.
# Keys are strings of the form 'Symbol(oxidation_state)', e.g. 'Fe(2)'; values are ints.
# Ground states were determined from the NIST Atomic Spectra Database
# (https://physics.nist.gov/PhysRefData/ASD/levels_form.html, accessed 2025).
# NOTE: not every oxidation state is tabulated (e.g. there is no 'Hf(1)' or
# 'Zr(1)' entry), so lookups must be prepared for a missing key.
metal_ox_dlist = {'Ag(1)': 10, 'Ag(2)': 9, 'Ag(3)': 8,
'Au(0)': 10, 'Au(1)': 10, 'Au(2)': 9, 'Au(3)': 8, 'Au(4)': 7, 'Au(5)': 6,
'Cd(1)': 10, 'Cd(2)': 10,
'Co(0)': 7, 'Co(1)': 8, 'Co(2)': 7, 'Co(3)': 6, 'Co(4)': 5, 'Co(5)': 4,
'Cr(0)': 5, 'Cr(1)': 5, 'Cr(2)': 4, 'Cr(3)': 3, 'Cr(4)': 2, 'Cr(5)': 1, 'Cr(6)': 0, 'Cr(7)': 0,
'Cu(0)': 10, 'Cu(1)': 10, 'Cu(2)': 9, 'Cu(3)': 8, 'Cu(4)': 7,
'Fe(0)': 6, 'Fe(1)': 6, 'Fe(2)': 6, 'Fe(3)': 5, 'Fe(4)': 4, 'Fe(5)': 3,
'Hf(0)': 2, 'Hf(2)': 2, 'Hf(3)': 1, 'Hf(4)': 0,
'Hg(1)': 10, 'Hg(2)': 10,
'Ir(0)': 7, 'Ir(1)': 7, 'Ir(2)': 7, 'Ir(3)': 6, 'Ir(4)': 5, 'Ir(5)': 4,
'Mn(0)': 5, 'Mn(1)': 5, 'Mn(2)': 5, 'Mn(3)': 4, 'Mn(4)': 3, 'Mn(5)': 2, 'Mn(6)': 1, 'Mn(7)': 0,
'Mo(0)': 5, 'Mo(1)': 5, 'Mo(2)': 4, 'Mo(3)': 3, 'Mo(4)': 2, 'Mo(5)': 1, 'Mo(6)': 0, 'Mo(7)': 0,
'Nb(0)': 4, 'Nb(1)': 4, 'Nb(2)': 3, 'Nb(3)': 2, 'Nb(4)': 1, 'Nb(5)': 0,
'Ni(0)': 8, 'Ni(1)': 9, 'Ni(2)': 8, 'Ni(3)': 7, 'Ni(4)': 6, 'Ni(5)': 5,
'Os(0)': 6, 'Os(1)': 6, 'Os(2)': 6, 'Os(3)': 5, 'Os(4)': 4, 'Os(5)': 3, 'Os(6)': 2, 'Os(8)': 0,
'Pd(0)': 10, 'Pd(1)': 9, 'Pd(2)': 8, 'Pd(3)': 7, 'Pd(4)': 6,
'Pt(0)': 9, 'Pt(1)': 9, 'Pt(2)': 8, 'Pt(3)': 7, 'Pt(4)': 6, 'Pt(6)': 4,
'Re(0)': 5, 'Re(1)': 5, 'Re(2)': 5, 'Re(3)': 4, 'Re(4)': 3, 'Re(5)': 2, 'Re(6)': 1, 'Re(7)': 0,
'Rh(0)': 8, 'Rh(1)': 8, 'Rh(2)': 7, 'Rh(3)': 6, 'Rh(4)': 5, 'Rh(5)': 4,
'Ru(0)': 7, 'Ru(1)': 7, 'Ru(2)': 6, 'Ru(3)': 5, 'Ru(4)': 4, 'Ru(5)': 3, 'Ru(6)': 2,
'Sc(1)': 1, 'Sc(2)': 1, 'Sc(3)': 0,
'Ta(0)': 3, 'Ta(1)': 3, 'Ta(2)': 3, 'Ta(3)': 2, 'Ta(4)': 1, 'Ta(5)': 0,
'Tc(0)': 5, 'Tc(1)': 5, 'Tc(2)': 5, 'Tc(3)': 4, 'Tc(4)': 3, 'Tc(5)': 2, 'Tc(6)': 1, 'Tc(7)': 0,
'Ti(0)': 2, 'Ti(1)': 2, 'Ti(2)': 2, 'Ti(3)': 1, 'Ti(4)': 0,
'V(0)': 3, 'V(1)': 4, 'V(2)': 3, 'V(3)': 2, 'V(4)': 1, 'V(5)': 0,
'W(0)': 4, 'W(1)': 4, 'W(2)': 4, 'W(3)': 3, 'W(4)': 2, 'W(5)': 1, 'W(6)': 0,
'Y(2)': 1, 'Y(3)': 0, 'Y(4)': 0, 'Y(5)': 0,
'Zn(0)': 10, 'Zn(1)': 10, 'Zn(2)': 10, 'Zn(4)': 8, 'Zn(5)': 7,
'Zr(0)': 2, 'Zr(2)': 2, 'Zr(3)': 1, 'Zr(4)': 0}

# Allowed spins for each transition metal and oxidation state, keyed by
# 'Symbol(oxidation_state)' strings such as 'Fe(2)'.
# Each value is a list of ints ordered from low to high spin:
#   - 3d metals carry three entries (low, intermediate, high);
#   - 4d/5d metals carry two entries (low, intermediate), because only 3d
#     metals are assumed to exhibit high-spin states.
# The entries look like spin multiplicities (2S+1) -- TODO confirm.
# Unlike defaultspins, the values here are ints, not strings.
# Determined from the NIST Atomic Spectra Database
# (https://physics.nist.gov/PhysRefData/ASD/levels_form.html, accessed 2025).
metal_ox_spinlist = {'Ag(1)': [1, 3], 'Ag(2)': [2, 4], 'Ag(3)': [1, 3],
'Au(0)': [2, 4], 'Au(1)': [1, 3], 'Au(2)': [2, 4], 'Au(3)': [1, 3], 'Au(4)': [2, 4], 'Au(5)': [1, 3],
'Cd(1)': [2, 4], 'Cd(2)': [1, 3],
'Co(0)': [2, 4, 6], 'Co(1)': [1, 3, 5], 'Co(2)': [2, 4, 6], 'Co(3)': [1, 3, 5], 'Co(4)': [2, 4, 6], 'Co(5)': [1, 3, 5],
'Cr(0)': [1, 3, 5], 'Cr(1)': [2, 4, 6], 'Cr(2)': [1, 3, 5], 'Cr(3)': [2, 4, 6], 'Cr(4)': [1, 3, 5], 'Cr(5)': [2, 4, 6], 'Cr(6)': [1, 3, 5], 'Cr(7)': [2, 4, 6],
'Cu(0)': [2, 4, 6], 'Cu(1)': [1, 3, 5], 'Cu(2)': [2, 4, 6], 'Cu(3)': [1, 3, 5], 'Cu(4)': [2, 4, 6],
'Fe(0)': [1, 3, 5], 'Fe(1)': [2, 4, 6], 'Fe(2)': [1, 3, 5], 'Fe(3)': [2, 4, 6], 'Fe(4)': [1, 3, 5], 'Fe(5)': [2, 4, 6],
'Hf(0)': [1, 3], 'Hf(2)': [1, 3], 'Hf(3)': [2, 4], 'Hf(4)': [1, 3],
'Hg(1)': [2, 4], 'Hg(2)': [1, 3],
'Ir(0)': [2, 4], 'Ir(1)': [1, 3], 'Ir(2)': [2, 4], 'Ir(3)': [1, 3], 'Ir(4)': [2, 4], 'Ir(5)': [1, 3],
'Mn(0)': [2, 4, 6], 'Mn(1)': [1, 3, 5], 'Mn(2)': [2, 4, 6], 'Mn(3)': [1, 3, 5], 'Mn(4)': [2, 4, 6], 'Mn(5)': [1, 3, 5], 'Mn(6)': [2, 4, 6], 'Mn(7)': [1, 3, 5],
'Mo(0)': [1, 3], 'Mo(1)': [2, 4], 'Mo(2)': [1, 3], 'Mo(3)': [2, 4], 'Mo(4)': [1, 3], 'Mo(5)': [2, 4], 'Mo(6)': [1, 3], 'Mo(7)': [2, 4],
'Nb(0)': [2, 4], 'Nb(1)': [1, 3], 'Nb(2)': [2, 4], 'Nb(3)': [1, 3], 'Nb(4)': [2, 4], 'Nb(5)': [1, 3],
'Ni(0)': [1, 3, 5],'Ni(1)': [2, 4, 6], 'Ni(2)': [1, 3, 5], 'Ni(3)': [2, 4, 6], 'Ni(4)': [1, 3, 5], 'Ni(5)': [2, 4, 6],
'Os(0)': [1, 3], 'Os(1)': [2, 4], 'Os(2)': [1, 3], 'Os(3)': [2, 4], 'Os(4)': [1, 3], 'Os(5)': [2, 4], 'Os(6)': [1, 3], 'Os(8)': [1, 3],
'Pd(0)': [1, 3], 'Pd(1)': [2, 4], 'Pd(2)': [1, 3], 'Pd(3)': [2, 4], 'Pd(4)': [1, 3],
'Pt(0)': [1, 3], 'Pt(1)': [2, 4], 'Pt(2)': [1, 3], 'Pt(3)': [2, 4], 'Pt(4)': [1, 3], 'Pt(6)': [1, 3],
'Re(0)': [2, 4], 'Re(1)': [1, 3], 'Re(2)': [2, 4], 'Re(3)': [1, 3], 'Re(4)': [2, 4], 'Re(5)': [1, 3], 'Re(6)': [2, 4], 'Re(7)': [1, 3],
'Rh(0)': [2, 4], 'Rh(1)': [1, 3], 'Rh(2)': [2, 4], 'Rh(3)': [1, 3], 'Rh(4)': [2, 4], 'Rh(5)': [1, 3],
'Ru(0)': [1, 3], 'Ru(1)': [2, 4], 'Ru(2)': [1, 3], 'Ru(3)': [2, 4], 'Ru(4)': [1, 3], 'Ru(5)': [2, 4], 'Ru(6)': [1, 3],
'Sc(1)': [1, 3, 5], 'Sc(2)': [2, 4, 6], 'Sc(3)': [1, 3, 5],
'Ta(0)': [2, 4], 'Ta(1)': [1, 3], 'Ta(2)': [2, 4], 'Ta(3)': [1, 3], 'Ta(4)': [2, 4], 'Ta(5)': [1, 3],
'Tc(0)': [2, 4], 'Tc(1)': [1, 3], 'Tc(2)': [2, 4],'Tc(3)': [1, 3], 'Tc(4)': [2, 4], 'Tc(5)': [1, 3], 'Tc(6)': [2, 4], 'Tc(7)': [1, 3],
'Ti(0)': [1, 3, 5], 'Ti(1)': [2, 4, 6], 'Ti(2)': [1, 3, 5], 'Ti(3)': [2, 4, 6], 'Ti(4)': [1, 3, 5],
'V(0)': [2, 4, 6], 'V(1)': [1, 3, 5], 'V(2)': [2, 4, 6], 'V(3)': [1, 3, 5], 'V(4)': [2, 4, 6], 'V(5)': [1, 3, 5],
'W(0)': [1, 3], 'W(1)': [2, 4], 'W(2)': [1, 3], 'W(3)': [2, 4], 'W(4)': [1, 3], 'W(5)': [2, 4], 'W(6)': [1, 3],
'Y(2)': [2, 4], 'Y(3)': [1, 3], 'Y(4)': [2, 4], 'Y(5)': [1, 3],
'Zn(0)': [1, 3, 5], 'Zn(1)': [2, 4, 6], 'Zn(2)': [1, 3, 5], 'Zn(4)': [1, 3, 5], 'Zn(5)': [2, 4, 6],
'Zr(0)': [1, 3], 'Zr(2)': [1, 3], 'Zr(3)': [2, 4], 'Zr(4)': [1, 3]}

# Elements sorted by atomic number
elementsbynum = ['H', 'He',
'Li', 'Be', 'B', 'C', 'N', 'O', 'F', 'Ne',
Expand Down
201 changes: 200 additions & 1 deletion molSimplify/Scripts/enhanced_structgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from molSimplify.utils.openbabel_helpers import *
from molSimplify.Scripts.io import lig_load
from molSimplify.Classes.globalvars import vdwrad
from molSimplify.Classes.globalvars import globalvars
from molSimplify.Scripts.io import getlicores, lig_load_safe, parse_bracketed_list

from typing import Any, List, Dict, Tuple, Union, Optional
Expand Down Expand Up @@ -206,6 +207,10 @@
# Accumulate ALL ligands' haptic groups (GLOBAL indices) so we can re-apply anytime
all_haptic_groups_global = []

batslist = []
# NEW: collect (backbone_site_1based, core_atom_index_1based) for every filled site
backbone_core_pairs = [] # NEW

iteration = 0
for ligand in ligand_list:
donor_indices = list(ligand[1]) # may include lists (haptics) or ints
Expand Down Expand Up @@ -235,6 +240,7 @@
denticity = len(donor_groups)

structure = get_next_structure(metals_structures_copy, denticity)

if manual and iteration < len(manual_list):
valid_subsets = [manual_list[iteration]]
else:
Expand All @@ -244,6 +250,7 @@
print(f"Valid subsets: {valid_subsets}")
print(structure)

prev = structure['occupied_mask'].copy()
# group-aware Kabsch
best_subset, best_aligned_coords, best_rmsd, placement_attempts, best_perm_idx = clash_aware_kabsch(
ligand_all_coords,
Expand Down Expand Up @@ -276,6 +283,13 @@
assert best_subset != None, "Check total number of coordination sites in tested geometry, as well as total number of coordinating atoms"
structure['occupied_mask'][np.array(best_subset)] = True

curr = np.array(structure['occupied_mask'].copy())
# Find where a value flipped from False → True
new_true = np.where((~prev) & curr)[0] + 1 # +1 for 1-based indexing
batslist.append(list(new_true))
# Update stored state
prev = curr.copy()

Check warning

Code scanning / CodeQL

Variable defined multiple times Warning

This assignment to 'prev' is unnecessary as it is
redefined
before this value is used.

# group→site mapping (not strictly needed for bonding)
donor_to_site, site_indices_in_group_order = map_donors_to_sites_haptic_aware(
donor_groups=donor_groups,
Expand Down Expand Up @@ -330,6 +344,14 @@
)
# ---------------------------------------------------------------------------------

# NEW: record which core3D atom now occupies each backbone site filled by this ligand
# Use the representative donor atom for each donor group (donor_reps).
# Both backbone site indices and core atom indices reported as 1-based.
for site_idx_0b, rep_local_0b in zip(site_indices_in_group_order, donor_reps): # both 0-based
core_idx_1b = int(local2global[rep_local_0b]) + 1 # assume 0-based mapping → make 1-based
backbone_core_pairs.append((int(site_idx_0b) + 1, core_idx_1b)) # (backbone_site_1b, core_atom_1b)
# END NEW

# clean bonds & optimize
core3D.convert2OBMol(force_clean=True)
replace_bonds(core3D.OBMol, core3D.bo_dict)
Expand Down Expand Up @@ -423,8 +445,14 @@
# get sterics report
if run_sterics:
clashes, severity, fig = run_sterics_check(core3D, max_steps, ff_name)
else:
fig = None # just to be explicit

# NEW: produce the final ordered list of core indices per filled backbone site
backbone_core_indices = [core for (site, core) in sorted(backbone_core_pairs, key=lambda x: x[0])] # NEW

return core3D, clashes, severity, fig
# OLD + NEW: append backbone_core_indices at the end
return core3D, clashes, severity, fig, batslist, backbone_core_indices

def run_sterics_check(core3D, max_steps, ff_name):
optimized_coords, per_atom_ff_force = constrained_forcefield_optimization(
Expand Down Expand Up @@ -691,3 +719,174 @@
print(f"[ok] Wrote complex to {out_path} (format={out_fmt})")
except Exception as e:
print(f"[warn] Could not write file ({e}). Returning object only.")


def enhanced_init_ANN(metal, ox, spin, ligands, occs, dents,
                      batslist, tcats, licores, geometry):
    """Run the RACs-ANN preprocessing for a complex and unpack its bond lengths.

    Thin wrapper around ``tf_ANN_preproc`` that forwards the complex
    description and always returns a bond-length entry, even when the ANN
    call fails.

    Parameters
    ----------
    metal, ox, spin
        Metal, oxidation state and spin of the complex; forwarded unchanged
        to ``tf_ANN_preproc``.
    ligands : list
        List of ligands, given as names.
    occs : list
        List of ligand occupations (frequencies of each ligand).
    dents : list
        List of ligand denticities.
    batslist : list
        List of backbone points, one inner list per ligand.
    tcats : list
        List of SMILES ligand connecting atoms.
    licores : dict
        Ligand dictionary within molSimplify.
    geometry
        Geometry specification; forwarded unchanged to ``tf_ANN_preproc``.

    Returns
    -------
    ANN_flag : bool
        Whether an ANN call was successful.
    ANN_bondl
        ``ANN_attributes['ANN_bondl']`` on success; otherwise a list holding
        one ``False`` per backbone point in ``batslist``.
    ANN_reason : str
        Reason for ANN failure, if failed.
    ANN_attributes : dict
        Dictionary of predicted attributes of complex.
    catalysis_flag : bool
        Whether or not complex is compatible for catalytic ANNs.
    """
    # Imported lazily so the ANN stack is only loaded when an ANN call is made.
    from molSimplify.Scripts.tf_nn_prep import tf_ANN_preproc

    # Default value [] in case decoration is not used.
    decoration_index = []

    # NOTE(review): the two positional ``False`` arguments are presumably the
    # decoration and exchange switches of tf_ANN_preproc -- confirm against
    # its signature.
    ANN_flag, ANN_reason, ANN_attributes, catalysis_flag = tf_ANN_preproc(
        metal, ox, spin, ligands, occs, dents, batslist,
        tcats, licores, False, decoration_index, False,
        geometry, debug=False)

    if ANN_flag:
        ANN_bondl = ANN_attributes['ANN_bondl']
    else:
        # There needs to be one entry per possible ligand attachment point.
        n_sites = sum(1 for items in batslist for _ in items)
        ANN_bondl = n_sites * [False]
    return ANN_flag, ANN_bondl, ANN_reason, ANN_attributes, catalysis_flag

import numpy as np

def enforce_metal_ligand_distances_and_optimize(
    core3D,
    bondl,
    backbone_core_indices,  # 1-based indices
    *,
    ff_name: str = "UFF",
    max_steps: int = 500,
    constrain: bool = True,
    tolerate_zero_vec: float = 1e-6,
):
    """
    Set every metal-donor distance to a target value, then optionally relax
    the structure with a constrained force-field optimization.

    Parameters
    ----------
    core3D
        Complex to modify; must provide ``findMetal``, ``bo_dict`` and
        ``getAtom``.
    bondl
        One target distance per donor (scalar or length-1 sequence each), or
        None to use the metal radius plus the donor radius (``.rad``) per site.
    backbone_core_indices
        1-based atom indices of the donor atoms.
    ff_name, max_steps
        Forwarded to ``constrained_forcefield_optimization``.
    constrain
        When True, run the optimization with metals and donors frozen.
    tolerate_zero_vec
        Metal-donor separations below this value are treated as zero and the
        donor is placed along +z instead.

    Returns
    -------
    The object returned by ``set_new_coords`` after the final update.

    Raises
    ------
    RuntimeError
        No metal center was found.
    IndexError
        A donor index lies outside the atom range.
    ValueError
        ``bondl`` and ``backbone_core_indices`` differ in length.
    TypeError
        A bond-length entry is a sequence with more than one element.
    """

    def _as_float(x):
        # Plain scalars pass straight through; sequences must hold one value.
        if not isinstance(x, (list, tuple, np.ndarray)):
            return float(x)
        arr = np.asarray(x, dtype=float).reshape(-1)
        if arr.size != 1:
            raise TypeError(f"Each bond length must be a scalar; got shape {arr.shape} with values {arr}.")
        return float(arr[0])

    # 1) 0-based donor indices for internal use
    donors = [int(one_based) - 1 for one_based in backbone_core_indices]

    # 2) Coordinates and element symbols
    xyz, symbols = get_all_coords_and_elements(core3D)
    xyz = np.asarray(xyz, dtype=float)
    n_atoms = len(xyz)

    # 3) Metal centers: ask the molecule first, fall back to a symbol lookup
    try:
        metals = core3D.findMetal(transition_metals_only=True)
    except Exception:
        transition_metals = {
            "Sc", "Ti", "V", "Cr", "Mn", "Fe", "Co", "Ni", "Cu", "Zn",
            "Y", "Zr", "Nb", "Mo", "Tc", "Ru", "Rh", "Pd", "Ag", "Cd",
            "La", "Hf", "Ta", "W", "Re", "Os", "Ir", "Pt", "Au", "Hg",
        }
        metals = [pos for pos, sym in enumerate(symbols) if sym in transition_metals]

    if not metals:
        raise RuntimeError("No metal centers found in core3D.")

    metals = sorted({int(m) for m in metals})
    metal_xyz = xyz[metals]

    # 4) Bond lookup as order-independent (low, high) index pairs
    bonded = {(a, b) if a <= b else (b, a) for a, b in core3D.bo_dict}

    def _metal_of(atom):
        """Prefer the metal the atom is bonded to; else the nearest metal."""
        partner = next(
            (m for m in metals if (min(atom, m), max(atom, m)) in bonded),
            None,
        )
        if partner is not None:
            return partner
        offsets = metal_xyz - xyz[atom]
        return metals[int(np.argmin(np.sum(offsets**2, axis=1)))]

    def _require_in_range(donor_idx):
        if not 0 <= donor_idx < n_atoms:
            raise IndexError(f"Donor index {donor_idx+1} out of bounds (1-based input).")

    # Default targets: sum of the metal and donor radii for every site
    if bondl is None:
        bondl = []
        for donor in donors:
            _require_in_range(donor)
            metal = _metal_of(donor)
            metal_radius = float(core3D.getAtom(metal).rad)  # Å
            donor_radius = float(core3D.getAtom(donor).rad)  # Å
            bondl.append(metal_radius + donor_radius)
    elif len(bondl) != len(donors):
        raise ValueError("bondl and backbone_core_indices must have the same length")

    # 5) Slide each donor along its metal-donor line to the target distance
    moved = xyz.copy()
    for raw_target, donor in zip(bondl, donors):
        _require_in_range(donor)
        target = _as_float(raw_target)

        origin = xyz[_metal_of(donor)]
        separation = xyz[donor] - origin
        length = float(np.linalg.norm(separation))
        if length < tolerate_zero_vec:
            # Donor sits on the metal: no direction is defined, so use +z.
            unit = np.array([0.0, 0.0, 1.0], dtype=float)
        else:
            unit = separation / length

        moved[donor] = origin + unit * target

    core3D = set_new_coords(core3D, moved)

    # 6) Constrained FF optimization (freeze metals + donors)
    # NOTE(review): run_sterics_check in this file unpacks the result of
    # constrained_forcefield_optimization as a 2-tuple; confirm that handing
    # the raw return value to set_new_coords is intended here.
    if constrain:
        fixed_atoms = sorted(set(metals + donors))
        relaxed = constrained_forcefield_optimization(
            core3D,
            fixed_atoms,
            max_steps=max_steps,
            ff_name=ff_name,
        )
        core3D = set_new_coords(core3D, relaxed)

    return core3D
24 changes: 20 additions & 4 deletions molSimplify/Scripts/tf_nn_prep.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,10 +713,18 @@
print(
'warning, ANN predicts a near degenerate ground state for this complex')
print(f"ANN predicts a spin splitting (HS - LS) of {float(split[0]):.2f} kcal/mol at {100 * alpha:.0f}% HFX")
print('ANN low spin bond length (ax1/ax2/eq) is predicted to be: ' + " /".join(
[f"{float(i):.2f}" for i in r_ls[0]]) + ' angstrom')
print('ANN high spin bond length (ax1/ax2/eq) is predicted to be: ' + " /".join(
[f"{float(i):.2f}" for i in r_hs[0]]) + ' angstrom')

try:
print('ANN low spin bond length (ax1/ax2/eq) is predicted to be: ' + " /".join(
[f"{float(i):.2f}" for i in r_ls[0]]) + ' angstrom')
print('ANN high spin bond length (ax1/ax2/eq) is predicted to be: ' + " /".join(
[f"{float(i):.2f}" for i in r_hs[0]]) + ' angstrom')
except:

Check notice

Code scanning / CodeQL

Except block handles 'BaseException' Note

Except block directly handles BaseException.
vals = _flat_floats(r_ls)
print('ANN low spin bond length (ax1/ax2/eq) is predicted to be: ' + ', '.join(f"{v:.2f}" for v in vals) + ' angstrom')
vals = _flat_floats(r_hs)
print('ANN high spin bond length (ax1/ax2/eq) is predicted to be: ' + ', '.join(f"{v:.2f}" for v in vals) + ' angstrom')

print(f'distance to splitting energy training data is {split_dist:.2f}')
print(splitting_ANN_trust_message)
print()
Expand All @@ -734,6 +742,14 @@
return ANN_attributes


def _flat_floats(x):
import numpy as np
arr = np.asarray(x, dtype=float)
if arr.ndim > 1:
arr = arr[0]
return arr.ravel()


def evaluate_catalytic_anns(this_complex: mol3D, metal: str, ox: int, spin: int,
custom_ligand_dict: Dict[str, list],
net_lig_charge: int, exchange: Union[str, float, int],
Expand Down
Loading
Loading