Gromacs 2026.0-dev-20241213-9ac17bb
gmx::TorchModel Class Reference

#include <gromacs/applied_forces/nnpot/torchmodel.h>

Description

Class responsible for loading and evaluating a TorchScript-compiled neural network model. Implements the gmx::INNPotModel interface.
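
The call sequence implied by the public methods (initialize once, prepare inputs, run inference, collect outputs) can be sketched as follows. This is a minimal, self-contained sketch under stated assumptions: ModelSketch, HarmonicToyModel, and Vec3 are hypothetical names, the signatures are simplified (no GROMACS or LibTorch types), and a toy harmonic energy stands in for network inference. It also assumes that getOutputs adds forces into the global array at the given indices, which this page does not confirm.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for gmx::RVec (not the real GROMACS type).
using Vec3 = std::array<float, 3>;

// Hypothetical, simplified interface mirroring the order in which the
// model-facing methods are used: prepare inputs -> evaluate -> collect outputs.
class ModelSketch
{
public:
    virtual ~ModelSketch() = default;
    virtual void initModel() = 0;
    virtual void prepareAtomPositions(std::vector<Vec3>& positions) = 0;
    virtual void evaluateModel() = 0;
    virtual void getOutputs(const std::vector<int>& indices, double& energy,
                            std::vector<Vec3>& forces, bool provideForces) = 0;
};

// Toy model: E = 0.5 * k * sum |x|^2 and F = -k * x.
// It replaces network inference only to make the data flow concrete.
class HarmonicToyModel : public ModelSketch
{
public:
    explicit HarmonicToyModel(float k) : k_(k) {}

    void initModel() override { initialized_ = true; }

    void prepareAtomPositions(std::vector<Vec3>& positions) override { positions_ = positions; }

    void evaluateModel() override
    {
        assert(initialized_); // initModel() must run before inference
        energy_ = 0.0;
        localForces_.assign(positions_.size(), Vec3{ 0.0f, 0.0f, 0.0f });
        for (std::size_t i = 0; i < positions_.size(); ++i)
        {
            for (int d = 0; d < 3; ++d)
            {
                energy_ += 0.5 * k_ * positions_[i][d] * positions_[i][d];
                localForces_[i][d] = -k_ * positions_[i][d];
            }
        }
    }

    void getOutputs(const std::vector<int>& indices, double& energy,
                    std::vector<Vec3>& forces, bool provideForces) override
    {
        energy += energy_;
        if (!provideForces)
        {
            return;
        }
        // Assumption: model-local forces are accumulated at global atom indices.
        for (std::size_t i = 0; i < indices.size(); ++i)
        {
            for (int d = 0; d < 3; ++d)
            {
                forces[indices[i]][d] += localForces_[i][d];
            }
        }
    }

private:
    float             k_;
    bool              initialized_ = false;
    std::vector<Vec3> positions_;
    std::vector<Vec3> localForces_;
    double            energy_ = 0.0;
};
```

For example, with k = 2 and two atoms at (1, 0, 0) and (0, 2, 0) mapped to global indices 3 and 1, the collected energy is 5 and the forces added to atoms 3 and 1 are (-2, 0, 0) and (0, -4, 0).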

Public Member Functions

 TorchModel (const std::string &filename, const MDLogger *logger)
 Constructor for TorchModel.
 
void initModel () override
 Initialize the neural network model.
 
void evaluateModel () override
 Run inference on the NN model.
 
void getOutputs (std::vector< int > &indices, gmx_enerdata_t &enerd, const ArrayRef< RVec > &forces, bool provideForces) override
 Retrieve the NN model outputs (energy and, if provideForces is true, forces).
 
void setCommRec (const t_commrec *cr) override
 Set the communication record used to communicate input/output data between ranks when needed.
 
void setDevice ()
 Determine which device to use, based on the GMX_NN_DEVICE environment variable.
 
void loadModelExtensions (std::string &extension_libs)
 Load the custom extensions, listed in the model's extra_files, that were used to compile the TorchScript model.
 
void prepareAtomPositions (std::vector< RVec > &positions) override
 Functions to prepare inputs for the NN model.
 
void prepareAtomNumbers (std::vector< int > &atomTypes) override
 
void prepareBox (matrix &box) override
 
void preparePbcType (PbcType &pbcType) override
 

Constructor & Destructor Documentation

gmx::TorchModel::TorchModel (const std::string &filename, const MDLogger *logger)

Constructor for TorchModel.

Parameters
[in] filename  Path to the TorchScript model file.
[in] logger  Pointer to the MDLogger.

Member Function Documentation

void gmx::TorchModel::prepareAtomPositions (std::vector< RVec > &positions)  [override, virtual]

Functions to prepare inputs for the NN model.

Create input torch::Tensors for the model.

Implements gmx::INNPotModel.
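
Creating input tensors typically means exposing atom data as a contiguous buffer with a fixed shape. The sketch below packs positions into a row-major [N, 3] buffer and a box into a [3, 3] buffer. Vec3, Box, flattenPositions, and flattenBox are hypothetical names, float precision is assumed, and the final wrapping into a torch::Tensor (for instance with torch::from_blob) is left out so the example stays self-contained.

```cpp
#include <array>
#include <cassert>
#include <vector>

// Hypothetical stand-ins: GROMACS RVec is a 3-vector of `real` and
// `matrix` is real[3][3]; single precision (float) is assumed here.
using Vec3 = std::array<float, 3>;
using Box  = float[3][3];

// Pack N positions into a contiguous row-major buffer of shape [N, 3],
// the memory layout a tensor of shape {N, 3} would expect.
std::vector<float> flattenPositions(const std::vector<Vec3>& positions)
{
    std::vector<float> buffer;
    buffer.reserve(positions.size() * 3);
    for (const Vec3& x : positions)
    {
        buffer.insert(buffer.end(), x.begin(), x.end());
    }
    assert(buffer.size() == positions.size() * 3); // shape check: N rows of 3
    return buffer;
}

// Pack the 3x3 box into a row-major buffer of 9 values (shape [3, 3]).
std::vector<float> flattenBox(const Box& box)
{
    std::vector<float> buffer;
    buffer.reserve(9);
    for (int row = 0; row < 3; ++row)
    {
        for (int col = 0; col < 3; ++col)
        {
            buffer.push_back(box[row][col]);
        }
    }
    return buffer;
}
```

Keeping the buffer row-major means element (i, d) of the tensor sits at offset 3 * i + d, so no copy or transpose is needed when the buffer is wrapped.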


The documentation for this class was generated from the following files: