Gromacs 2026.0-dev-20241213-9ac17bb
#include <gromacs/applied_forces/nnpot/torchmodel.h>
Class responsible for loading and evaluating a TorchScript-compiled neural network model. Implements the gmx::INNPotModel interface.
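The class description and the member list point to a load-then-evaluate lifecycle: initialize the model once, then on every step prepare the inputs, run inference and collect the outputs. The sketch below illustrates that call order under stated assumptions. `NNModelSketch`, `HarmonicToyModel`, `Vec3`, `StepResult` and `runToyStep` are hypothetical names; the signatures are simplified stand-ins for the real GROMACS types (`RVec`, `gmx_enerdata_t`, `ArrayRef`), a toy harmonic energy replaces an actual TorchScript network, and `indices[i]` is assumed to give the slot in the force array for the i-th model atom.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for gmx::RVec.
using Vec3 = std::array<float, 3>;

// Hypothetical interface mirroring a subset of the documented members
// (simplified signatures; this is not the real gmx::INNPotModel).
class NNModelSketch
{
public:
    virtual ~NNModelSketch()                                         = default;
    virtual void initModel()                                         = 0;
    virtual void prepareAtomPositions(std::vector<Vec3>& positions) = 0;
    virtual void evaluateModel()                                     = 0;
    virtual void getOutputs(const std::vector<int>& indices, double& energy,
                            std::vector<Vec3>& forces, bool provideForces) = 0;
};

// Toy "network": E = 0.5 * k * sum |x|^2 and F = -k * x.
class HarmonicToyModel : public NNModelSketch
{
public:
    explicit HarmonicToyModel(float k) : k_(k) {}

    void initModel() override { initialized_ = true; }

    void prepareAtomPositions(std::vector<Vec3>& positions) override { positions_ = positions; }

    void evaluateModel() override
    {
        assert(initialized_ && "initModel() must be called before evaluateModel()");
        energy_ = 0.0;
        forces_.assign(positions_.size(), Vec3{ 0.0f, 0.0f, 0.0f });
        for (std::size_t i = 0; i < positions_.size(); ++i)
        {
            for (int d = 0; d < 3; ++d)
            {
                energy_ += 0.5 * k_ * positions_[i][d] * positions_[i][d];
                forces_[i][d] = -k_ * positions_[i][d];
            }
        }
    }

    void getOutputs(const std::vector<int>& indices, double& energy,
                    std::vector<Vec3>& forces, bool provideForces) override
    {
        energy += energy_; // accumulate into the caller's energy term
        if (!provideForces)
        {
            return;
        }
        // Assumption: indices[i] is the slot in `forces` for model atom i.
        for (std::size_t i = 0; i < indices.size(); ++i)
        {
            for (int d = 0; d < 3; ++d)
            {
                forces[indices[i]][d] += forces_[i][d];
            }
        }
    }

private:
    float             k_;
    bool              initialized_ = false;
    std::vector<Vec3> positions_;
    std::vector<Vec3> forces_;
    double            energy_ = 0.0;
};

struct StepResult
{
    double            energy;
    std::vector<Vec3> forces;
};

// Drives one step in the assumed order: init once, then prepare, evaluate, collect.
StepResult runToyStep()
{
    HarmonicToyModel model(2.0f);
    model.initModel();
    std::vector<Vec3> positions = { Vec3{ 1.0f, 0.0f, 0.0f }, Vec3{ 0.0f, 2.0f, 0.0f } };
    model.prepareAtomPositions(positions);
    model.evaluateModel();
    StepResult result{ 0.0, std::vector<Vec3>(4, Vec3{ 0.0f, 0.0f, 0.0f }) };
    // The two model atoms are atoms 1 and 3 of a hypothetical 4-atom system.
    model.getOutputs({ 1, 3 }, result.energy, result.forces, true);
    return result;
}
```

Calling `runToyStep()` gives an energy of 5.0 and places the forces (-2, 0, 0) and (0, -4, 0) on atoms 1 and 3. `setCommRec()`, `prepareAtomNumbers()`, `prepareBox()` and `preparePbcType()` are left out of the sketch for brevity.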
Public Member Functions
  TorchModel(const std::string &filename, const MDLogger *logger)
      Constructor for TorchModel.
  void initModel() override
      Initialize the neural network model.
  void evaluateModel() override
      Run inference on the NN model.
  void getOutputs(std::vector<int> &indices, gmx_enerdata_t &enerd, const ArrayRef<RVec> &forces, bool provideForces) override
      Retrieve the NN model outputs.
  void setCommRec(const t_commrec *cr) override
      Set the communication record used for any communication of input/output data between ranks.
  void setDevice()
      Determine which device to use, based on the GMX_NN_DEVICE environment variable.
  void loadModelExtensions(std::string &extension_libs)
      Load the custom extensions, as specified in extra_files, that were used to compile the TorchScript model.
  void prepareAtomPositions(std::vector<RVec> &positions) override
  void prepareAtomNumbers(std::vector<int> &atomTypes) override
  void prepareBox(matrix &box) override
  void preparePbcType(PbcType &pbcType) override
      Functions to prepare inputs for the NN model.
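`setDevice()` chooses the compute device from the `GMX_NN_DEVICE` environment variable. A minimal sketch of that kind of selection follows. It assumes the variable accepts `cpu` or `cuda` and that CPU is the default when the variable is unset; both the accepted values and the default are assumptions, not taken from this page. `NNDevice` and `selectDeviceFromEnv` are hypothetical names standing in for `torch::Device` handling, so the code builds without libtorch.

```cpp
#include <cstdlib>
#include <stdexcept>
#include <string>

// Hypothetical stand-in for torch::DeviceType.
enum class NNDevice
{
    Cpu,
    Cuda
};

// Sketch of environment-driven device selection.
// Assumptions: GMX_NN_DEVICE accepts "cpu" or "cuda"; unset means CPU.
NNDevice selectDeviceFromEnv()
{
    const char* env = std::getenv("GMX_NN_DEVICE");
    if (env == nullptr)
    {
        return NNDevice::Cpu; // assumed default
    }
    const std::string value(env);
    if (value == "cpu")
    {
        return NNDevice::Cpu;
    }
    if (value == "cuda")
    {
        return NNDevice::Cuda;
    }
    // Reject anything else instead of silently falling back.
    throw std::invalid_argument("Unsupported GMX_NN_DEVICE value: " + value);
}
```

Throwing on an unknown value, rather than quietly falling back to CPU, keeps a misspelled setting from going unnoticed; whether GROMACS itself behaves this way is not stated here.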
gmx::TorchModel::TorchModel(const std::string &filename, const MDLogger *logger)
Constructor for TorchModel.
Parameters
  [in] filename  path to the TorchScript model file
  [in] logger    pointer to the MDLogger
void gmx::TorchModel::prepareAtomPositions(std::vector<RVec> &positions)  [override, virtual]
Functions to prepare inputs for the NN model.
Create the input torch::Tensor objects for the model.
Implements gmx::INNPotModel.
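`prepareAtomPositions()` builds input `torch::Tensor` objects from a `std::vector<RVec>`. One plausible intermediate step is to pack the positions into a contiguous row-major buffer of shape [N, 3]. In this sketch, `Vec3` (standing in for a single-precision `gmx::RVec`), `FlatPositions` and `flattenPositions` are hypothetical, and the [N, 3] layout is an assumption about what the model expects.

```cpp
#include <array>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for gmx::RVec in a single-precision build.
using Vec3 = std::array<float, 3>;

// Contiguous row-major buffer plus its shape, e.g. {N, 3}.
struct FlatPositions
{
    std::vector<float>   data;
    std::vector<int64_t> shape;
};

// Packs N positions into an [N, 3] row-major buffer (x0 y0 z0 x1 y1 z1 ...).
FlatPositions flattenPositions(const std::vector<Vec3>& positions)
{
    FlatPositions out;
    out.shape = { static_cast<int64_t>(positions.size()), 3 };
    out.data.reserve(positions.size() * 3);
    for (const Vec3& x : positions)
    {
        out.data.insert(out.data.end(), x.begin(), x.end());
    }
    return out;
}
```

With libtorch available, such a buffer could be wrapped with `torch::from_blob(flat.data.data(), flat.shape, torch::kFloat32)`. Because `from_blob` does not take ownership of the memory, the buffer must outlive the tensor, or the tensor should be `.clone()`d.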