From 95bd47cb99ba4ab0c27adf721f639ab43565fb45 Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Fri, 28 Nov 2025 14:10:00 +0800 Subject: [PATCH 1/3] fix: Ensure proper tensor conversion for numpy solver in Eigenvalues class --- dptb/nn/energy.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/dptb/nn/energy.py b/dptb/nn/energy.py index 4ea03324..6ada2609 100644 --- a/dptb/nn/energy.py +++ b/dptb/nn/energy.py @@ -90,16 +90,19 @@ def forward(self, chklowtinv = torch.linalg.inv(chklowt) data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj()) elif eig_solver == 'numpy': - chklowt = np.linalg.cholesky(data[self.s_out_field].detach().numpy()) + chklowt = np.linalg.cholesky(data[self.s_out_field].detach().cpu().numpy()) chklowtinv = np.linalg.inv(chklowt) - data[self.h_out_field] = (chklowtinv @ data[self.h_out_field].detach().numpy() @ np.transpose(chklowtinv,(0,2,1)).conj()) - else: - data[self.h_out_field] = data[self.h_out_field] - + data[self.h_out_field] = (chklowtinv @ data[self.h_out_field].detach().cpu().numpy() @ np.transpose(chklowtinv,(0,2,1)).conj()) + elif eig_solver == 'numpy': + # Convert to numpy when using numpy solver without overlap + data[self.h_out_field] = data[self.h_out_field].detach().cpu().numpy() + if eig_solver == 'torch': eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field])) elif eig_solver == 'numpy': - eigvals.append(torch.from_numpy(np.linalg.eigvalsh(a=data[self.h_out_field]))) + eigvals_np = np.linalg.eigvalsh(a=data[self.h_out_field]) + # Preserve dtype by converting to the Hamiltonian's original dtype + eigvals.append(torch.from_numpy(eigvals_np).to(dtype=self.h2k.dtype)) data[self.out_field] = torch.nested.as_nested_tensor([torch.cat(eigvals, dim=0)]) if nested: From c6f87e6e00e4aa1b96ad7eb271507421672bcd4c Mon Sep 17 00:00:00 2001 From: AsymmetryChou <181240085@smail.nju.edu.cn> Date: Fri, 28 Nov 2025 14:26:49 +0800 Subject: [PATCH 2/3] use local variable to ensure dtype correctness in data.dict --- dptb/nn/energy.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/dptb/nn/energy.py b/dptb/nn/energy.py index 6ada2609..6dea5395 100644 --- a/dptb/nn/energy.py +++ b/dptb/nn/energy.py @@ -90,19 +90,18 @@ def forward(self, chklowtinv = torch.linalg.inv(chklowt) data[self.h_out_field] = (chklowtinv @ data[self.h_out_field] @ torch.transpose(chklowtinv,dim0=1,dim1=2).conj()) elif eig_solver == 'numpy': - chklowt = np.linalg.cholesky(data[self.s_out_field].detach().cpu().numpy()) + s_np = data[self.s_out_field].detach().cpu().numpy() + h_np = data[self.h_out_field].detach().cpu().numpy() + chklowt = np.linalg.cholesky(s_np) chklowtinv = np.linalg.inv(chklowt) - data[self.h_out_field] = (chklowtinv @ data[self.h_out_field].detach().cpu().numpy() @ np.transpose(chklowtinv,(0,2,1)).conj()) - elif eig_solver == 'numpy': - # Convert to numpy when using numpy solver without overlap - data[self.h_out_field] = data[self.h_out_field].detach().cpu().numpy() + h_transformed_np = chklowtinv @ h_np @ np.transpose(chklowtinv,(0,2,1)).conj() if eig_solver == 'torch': eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field])) elif eig_solver == 'numpy': - eigvals_np = np.linalg.eigvalsh(a=data[self.h_out_field]) + eigvals_np = np.linalg.eigvalsh(a=h_transformed_np) # Preserve dtype by converting to the Hamiltonian's original dtype - eigvals.append(torch.from_numpy(eigvals_np).to(dtype=self.h2k.dtype)) + eigvals.append(torch.from_numpy(eigvals_np).to(dtype=self.h2k.dtype, device=self.h2k.device)) data[self.out_field] = torch.nested.as_nested_tensor([torch.cat(eigvals, dim=0)]) if nested: From 52d59285379b2f0dc7baa425bd101c0e9409a905 Mon Sep 17 00:00:00 2001 From: YiTian Yang <79531875+Lonya0@users.noreply.github.com> Date: Sun, 30 Nov 2025 00:29:23 +0800 Subject: [PATCH 3/3] fallback for h_transformed_np when overlap is False fallback for h_transformed_np when overlap is False --- dptb/nn/energy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/dptb/nn/energy.py b/dptb/nn/energy.py index 6dea5395..6103bad0 100644 --- a/dptb/nn/energy.py +++ b/dptb/nn/energy.py @@ -83,6 +83,7 @@ def forward(self, for i in range(int(np.ceil(num_k / nk))): data[AtomicDataDict.KPOINT_KEY] = kpoints[i*nk:(i+1)*nk] data = self.h2k(data) + h_transformed_np = None if self.overlap: data = self.s2k(data) if eig_solver == 'torch': @@ -99,6 +100,8 @@ def forward(self, if eig_solver == 'torch': eigvals.append(torch.linalg.eigvalsh(data[self.h_out_field])) elif eig_solver == 'numpy': + if h_transformed_np is None: + h_transformed_np = data[self.h_out_field].detach().cpu().numpy() eigvals_np = np.linalg.eigvalsh(a=h_transformed_np) # Preserve dtype by converting to the Hamiltonian's original dtype eigvals.append(torch.from_numpy(eigvals_np).to(dtype=self.h2k.dtype, device=self.h2k.device))