Skip to content

Commit

Permalink
Init branch
Browse files Browse the repository at this point in the history
  • Loading branch information
Chengqian-Zhang committed Mar 10, 2025
1 parent 228062c commit 975d9e4
Show file tree
Hide file tree
Showing 7 changed files with 1,055 additions and 156 deletions.
289 changes: 206 additions & 83 deletions deepmd/pt/loss/denoise.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
import numpy as np
import torch
import torch.nn.functional as F

Expand All @@ -8,102 +9,224 @@
from deepmd.pt.utils import (
env,
)
from deepmd.pt.utils.env import (
GLOBAL_PT_FLOAT_PRECISION,
)
from deepmd.utils.data import (
DataRequirementItem,
)
from deepmd.pt.utils.region import (
phys2inter,
inter2phys,
)

def get_cell_perturb_matrix(cell_pert_fraction: float):
# TODO: user fix some component
if cell_pert_fraction < 0:
raise RuntimeError("cell_pert_fraction can not be negative")
e0 = torch.rand(6)
e = e0 * 2 * cell_pert_fraction - cell_pert_fraction
cell_pert_matrix = torch.tensor(
[
[1 + e[0], 0, 0],
[e[5], 1 + e[1], 0],
[e[4], e[3], 1 + e[2]],
],
dtype=env.GLOBAL_PT_FLOAT_PRECISION,
device=env.DEVICE
)
return cell_pert_matrix, e

class DenoiseLoss(TaskLoss):
def __init__(
self,
ntypes,
masked_token_loss=1.0,
masked_coord_loss=1.0,
norm_loss=0.01,
use_l1=True,
beta=1.00,
mask_loss_coord=True,
mask_loss_token=True,
mask_token: bool = False,
mask_coord: bool = True,
mask_cell: bool = False,
token_loss: float = 1.0,
coord_loss: float = 1.0,
cell_loss: float = 1.0,
noise_type: str = "gaussian",
coord_noise: float = 0.2,
cell_pert_fraction: float = 0.0,
noise_mode: str = "prob",
mask_num: int = 1,
mask_prob: float = 0.2,
loss_func: str = "rmse",
**kwargs,
) -> None:
"""Construct a layer to compute loss on coord, and type reconstruction."""
r"""Construct a layer to compute loss on token, coord and cell.
Parameters
----------
mask_token: bool
Whether to mask token.
mask_coord: bool
Whether to mask coordinate.
mask_cell: bool
Whether to mask cell.
token_loss: float
The preference factor for token denoise.
coord_loss: float
The preference factor for coordinate denoise.
cell_loss: float
The preference factor for cell denoise.
noise_type : str
The type of noise to add to the coordinate. It can be 'uniform' or 'gaussian'.
coord_noise : float
The magnitude of noise to add to the coordinate.
cell_pert_fraction: float
A value determines how much will cell deform.
noise_mode : str
"'prob' means the noise is added with a probability.'fix_num' means the noise is added with a fixed number."
mask_num : int
The number of atoms to mask coordinates. It is only used when noise_mode is 'fix_num'.
mask_prob : float
The probability of masking coordinates. It is only used when noise_mode is 'prob'.
loss_func: str
The loss function to minimize, it can be 'mae' or 'rmse'.
**kwargs
Other keyword arguments.
"""
super().__init__()
self.ntypes = ntypes
self.masked_token_loss = masked_token_loss
self.masked_coord_loss = masked_coord_loss
self.norm_loss = norm_loss
self.has_coord = self.masked_coord_loss > 0.0
self.has_token = self.masked_token_loss > 0.0
self.has_norm = self.norm_loss > 0.0
self.use_l1 = use_l1
self.beta = beta
self.frac_beta = 1.00 / self.beta
self.mask_loss_coord = mask_loss_coord
self.mask_loss_token = mask_loss_token

def forward(self, model_pred, label, natoms, learning_rate, mae=False):
"""Return loss on coord and type denoise.
self.mask_token = mask_token
self.mask_coord = mask_coord
self.mask_cell = mask_cell
self.token_loss = token_loss
self.coord_loss = coord_loss
self.cell_loss = cell_loss
self.noise_type = noise_type
self.coord_noise = coord_noise
self.cell_pert_fraction = cell_pert_fraction
self.noise_mode = noise_mode
self.mask_num = mask_num
self.mask_prob = mask_prob
self.loss_func = loss_func

def forward(self, input_dict, model, label, natoms, learning_rate, mae=False):
"""Return loss on token,coord and cell.
Parameters
----------
input_dict : dict[str, torch.Tensor]
Model inputs.
model : torch.nn.Module
Model to be used to output the predictions.
label : dict[str, torch.Tensor]
Labels.
natoms : int
The local atom number.
Returns
-------
- loss: Loss to minimize.
model_pred: dict[str, torch.Tensor]
Model predictions.
loss: torch.Tensor
Loss for model to minimize.
more_loss: dict[str, torch.Tensor]
Other losses for display.
"""
updated_coord = model_pred["updated_coord"]
logits = model_pred["logits"]
clean_coord = label["clean_coord"]
clean_type = label["clean_type"]
coord_mask = label["coord_mask"]
type_mask = label["type_mask"]

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if self.has_coord:
if self.mask_loss_coord:
masked_updated_coord = updated_coord[coord_mask]
masked_clean_coord = clean_coord[coord_mask]
if masked_updated_coord.size(0) > 0:
coord_loss = F.smooth_l1_loss(
masked_updated_coord.view(-1, 3),
masked_clean_coord.view(-1, 3),
reduction="mean",
beta=self.beta,
)
else:
coord_loss = torch.zeros(
1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)[0]
nloc = input_dict["atype"].shape[1]
nbz = input_dict["atype"].shape[0]
input_dict["box"] = input_dict["box"].cuda()

# TODO: Change lattice to lower triangular matrix

label["clean_coord"] = input_dict["coord"].clone().detach()
label["clean_box"] = input_dict["box"].clone().detach()
origin_frac_coord = phys2inter(label["clean_coord"], label["clean_box"].reshape(nbz,3,3))
label["clean_frac_coord"] = origin_frac_coord.clone().detach()
if self.mask_cell:
strain_components_all = torch.zeros((nbz,3), dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
for ii in range(nbz):
cell_perturb_matrix, strain_components = get_cell_perturb_matrix_HEA(self.cell_noise)
# left-multiplied by `cell_perturb_matrix`` to get the noise box
input_dict["box"][ii] = torch.matmul(cell_perturb_matrix, input_dict["box"][ii].reshape(3,3)).reshape(-1)
input_dict["coord"][ii] = torch.matmul(origin_frac_coord[ii].reshape(nloc,3), input_dict["box"][ii].reshape(3,3))
strain_components_all[ii] = strain_components.reshape(-1)
label["strain_components"] = strain_components_all.clone().detach()

if self.mask_coord:
# add noise to coordinates and update label['updated_coord']
mask_num = 0
if self.noise_mode == "fix_num":
mask_num = self.mask_num
if(nloc < mask_num):
mask_num = nloc
elif self.noise_mode == "prob":
mask_num = int(self.mask_prob * nloc)
if mask_num == 0:
mask_num = 1
else:
coord_loss = F.smooth_l1_loss(
updated_coord.view(-1, 3),
clean_coord.view(-1, 3),
reduction="mean",
beta=self.beta,
)
loss += self.masked_coord_loss * coord_loss
more_loss["coord_l1_error"] = coord_loss.detach()
if self.has_token:
if self.mask_loss_token:
masked_logits = logits[type_mask]
masked_target = clean_type[type_mask]
if masked_logits.size(0) > 0:
token_loss = F.nll_loss(
F.log_softmax(masked_logits, dim=-1),
masked_target,
reduction="mean",
NotImplementedError(f"Unknown noise mode {self.noise_mode}!")

coord_mask_all = torch.zeros(input_dict["atype"].shape, dtype=torch.bool, device=env.DEVICE)
for ii in range(nbz):
noise_on_coord = 0.0
coord_mask_res = np.random.choice(range(nloc), mask_num, replace=False).tolist()
coord_mask = np.isin(range(nloc), coord_mask_res) # nloc
if self.noise_type == "uniform":
noise_on_coord = np.random.uniform(
low=-self.noise, high=self.noise, size=(mask_num, 3)
)
elif self.noise_type == "gaussian":
noise_on_coord = np.random.normal(
loc=0.0, scale=self.noise, size=(mask_num, 3)
)
else:
token_loss = torch.zeros(
1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)[0]
else:
token_loss = F.nll_loss(
F.log_softmax(logits.view(-1, self.ntypes - 1), dim=-1),
clean_type.view(-1),
reduction="mean",
)
loss += self.masked_token_loss * token_loss
more_loss["token_error"] = token_loss.detach()
if self.has_norm:
norm_x = model_pred["norm_x"]
norm_delta_pair_rep = model_pred["norm_delta_pair_rep"]
loss += self.norm_loss * (norm_x + norm_delta_pair_rep)
more_loss["norm_loss"] = norm_x.detach() + norm_delta_pair_rep.detach()

return loss, more_loss
raise NotImplementedError(f"Unknown noise type {self.noise_type}!")

noise_on_coord = torch.tensor(noise_on_coord, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE) # mask_num 3
input_dict["coord"][ii][coord_mask ,:] += noise_on_coord # nbz mask_num 3 //
coord_mask_all[ii] = torch.tensor(coord_mask, dtype=torch.bool, device=env.DEVICE)
label['coord_mask'] = coord_mask_all
frac_coord = phys2inter(input_dict["coord"], input_dict["box"].reshape(nbz,3,3))
#label["updated_coord"] = (label["clean_frac_coord"] - frac_coord).clone().detach()
label["updated_coord"] = ((label["clean_frac_coord"] - frac_coord) @ label["clean_box"].reshape(nbz,3,3)).clone().detach()

if self.mask_token:
# TODO: mask_token
pass

if (not self.mask_coord) and (not self.mask_cell) and (not self.mask_token):
raise RuntimeError("At least one of mask_coord, mask_cell and mask_token should be True!")

model_pred = model(**input_dict)

loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}

diff_coord = (label["updated_coord"] - model_pred["updated_coord"]).reshape(-1)
diff_cell = (label["strain_components"] - model_pred["strain_components"]).reshape(-1)
if self.loss_func == "rmse":
l2_coord_loss = torch.mean(torch.square(diff_coord))
l2_cell_loss = torch.mean(torch.square(diff_cell))
rmse_f = l2_coord_loss.sqrt()
rmse_v = l2_cell_loss.sqrt()
more_loss["rmse_coord"] = rmse_f.detach()
more_loss["rmse_cell"] = rmse_v.detach()
loss += self.coord_loss * l2_coord_loss.to(GLOBAL_PT_FLOAT_PRECISION) + self.cell_loss * l2_cell_loss.to(GLOBAL_PT_FLOAT_PRECISION)
elif self.loss_func == "mae":
l1_coord_loss = F.l1_loss(label["updated_coord"], model_pred["updated_coord"], reduction="none")
l1_cell_loss = F.l1_loss(label["strain_components"], model_pred["strain_components"], reduction="none")
more_loss["mae_coord"] = l1_coord_loss.mean().detach()
more_loss["mae_cell"] = l1_cell_loss.mean().detach()
l1_coord_loss = l1_coord_loss.sum(-1).mean(-1).sum()
l1_cell_loss = l1_cell_loss.sum()
loss += self.coord_loss * l1_coord_loss.to(GLOBAL_PT_FLOAT_PRECISION) + self.cell_loss * l1_cell_loss.to(GLOBAL_PT_FLOAT_PRECISION)
else:
raise RuntimeError(f"Unknown loss function {self.loss_func}!")
return model_pred, loss, more_loss

@property
def label_requirement(self) -> list[DataRequirementItem]:
"""Return data label requirements needed for this loss calculation."""
return []

def serialize(self) -> dict:
pass

@classmethod
def deserialize(cls, data: dict) -> "TaskLoss":
pass
4 changes: 4 additions & 0 deletions deepmd/pt/model/atomic_model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
from .property_atomic_model import (
DPPropertyAtomicModel,
)
from .denoise_atomic_model import (
DPDenoiseAtomicModel
)

__all__ = [
"BaseAtomicModel",
Expand All @@ -54,4 +57,5 @@
"DPZBLLinearEnergyAtomicModel",
"LinearEnergyAtomicModel",
"PairTabAtomicModel",
"DPDenoiseAtomicModel",
]
61 changes: 61 additions & 0 deletions deepmd/pt/model/atomic_model/denoise_atomic_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-License-Identifier: LGPL-3.0-or-later

import torch
import logging

from deepmd.pt.model.task.denoise import (
DenoiseNet,
)

from .dp_atomic_model import (
DPAtomicModel,
)
from IPython import embed

log = logging.getLogger(__name__)

class DPDenoiseAtomicModel(DPAtomicModel):
def __init__(self, descriptor, fitting, type_map, **kwargs):
if not isinstance(fitting, DenoiseNet):
raise TypeError(
"fitting must be an instance of DenoiseNet for DPDenoiseAtomicModel"
)
super().__init__(descriptor, fitting, type_map, **kwargs)

def apply_out_stat(
self,
ret: dict[str, torch.Tensor],
atype: torch.Tensor,
):
# hack !!!
ret["virial"] = ret["virial"]/240
ret["force"] = ret["force"]/29

'''
virial = ret["virial"] # 原始形状 [nbz, nloc, 6]
# 批量处理所有元素(保留梯度)
# 重塑为二维张量以便处理 [batch_size * nloc, 9]
virial_2d = virial.view(-1, 6)
# 构建3x3对称矩阵(向量化操作)
# 每个元素的索引对应原始矩阵位置:
# [0, 1, 2] 为对角线元素
# [3, 4, 5] 对应下三角元素(自动保持对称性)
matrices = torch.zeros(virial_2d.size(0), 3, 3,
dtype=virial.dtype, device=virial.device)
# 填充对角线元素
matrices[:, 0, 0] = 1 + virial_2d[:, 0]
matrices[:, 1, 1] = 1 + virial_2d[:, 1]
matrices[:, 2, 2] = 1 + virial_2d[:, 2]
# 填充对称的非对角线元素
matrices[:, 0, 1] = matrices[:, 1, 0] = 0.5 * virial_2d[:, 5] # (0,1) & (1,0)
matrices[:, 0, 2] = matrices[:, 2, 0] = 0.5 * virial_2d[:, 4] # (0,2) & (2,0)
matrices[:, 1, 2] = matrices[:, 2, 1] = 0.5 * virial_2d[:, 3] # (1,2) & (2,1)
# 恢复原始形状 [nbz, nloc, 3, 3] -> [nbz, nloc, 9]
ret["virial"] = matrices.view(virial.shape[0], virial.shape[1], 9)
'''
return ret
Loading

0 comments on commit 975d9e4

Please sign in to comment.