Skip to content

Commit

Permalink
Typed tensors (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
dschwen committed Mar 10, 2025
1 parent b86a47c commit 0696d3d
Show file tree
Hide file tree
Showing 13 changed files with 307 additions and 240 deletions.
76 changes: 65 additions & 11 deletions include/problems/TensorProblem.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@
#include "DomainInterface.h"
#include "SwiftTypes.h"
#include "SwiftUtils.h"
#include "TensorBuffer.h"

#include "AuxiliarySystem.h"
#include "libmesh/petsc_vector.h"

#include <memory>
#include <torch/torch.h>

class UniformTensorMesh;
Expand Down Expand Up @@ -48,7 +50,9 @@ class TensorProblem : public FEProblem, public DomainInterface
// recompute quantities on grid size change
virtual void gridChanged();

virtual void addTensorBuffer(const std::string & buffer_name, InputParameters & parameters);
virtual void addTensorBuffer(const std::string & buffer_type,
const std::string & buffer_name,
InputParameters & parameters);

virtual void addTensorComputeInitialize(const std::string & compute_name,
const std::string & name,
Expand All @@ -70,12 +74,17 @@ class TensorProblem : public FEProblem, public DomainInterface
const std::string & name,
InputParameters & parameters);

torch::Tensor & getBuffer(const std::string & buffer_name);
const std::vector<torch::Tensor> & getBufferOld(const std::string & buffer_name,
unsigned int max_states);
/// returns teh current state of the tensor
template <typename T = torch::Tensor>
T & getBuffer(const std::string & buffer_name);

/// return the old states of the tensor
template <typename T = torch::Tensor>
const std::vector<T> & getBufferOld(const std::string & buffer_name, unsigned int max_states);

/// returns a reference to a copy of buffer_name that is guaranteed to be contiguous and located on the CPU device
const torch::Tensor & getCPUBuffer(const std::string & buffer_name);
template <typename T = torch::Tensor>
const T & getCPUBuffer(const std::string & buffer_name);

TensorOperatorBase & getOnDemandCompute(const std::string & name);

Expand Down Expand Up @@ -123,6 +132,10 @@ class TensorProblem : public FEProblem, public DomainInterface
/// perform output tasks
void executeTensorOutputs(const ExecFlagType & exec_type);

/// helper to get the TensorBuffer wrapper object that holds the actual tensor data
template <typename T = torch::Tensor>
TensorBuffer<T> & getBufferHelper(const std::string & buffer_name);

/// tensor options
const torch::TensorOptions _options;

Expand All @@ -140,13 +153,10 @@ class TensorProblem : public FEProblem, public DomainInterface
Real _output_time;

/// list of TensorBuffers (i.e. tensors)
std::map<std::string, torch::Tensor> _tensor_buffer;
std::map<std::string, std::shared_ptr<TensorBufferBase>> _tensor_buffer;

/// list of read-only CPU TensorBuffers (for MOOSE objects and outputs)
std::map<std::string, torch::Tensor> _tensor_cpu_buffer;

/// old buffers (stores max number of states, requested, and states)
std::map<std::string, std::pair<unsigned int, std::vector<torch::Tensor>>> _old_tensor_buffer;
/// set of tensors that need to be copied to the CPU
std::set<std::string> _cpu_tensor_buffers;

/// old timesteps
std::vector<Real> _old_dt;
Expand Down Expand Up @@ -208,3 +218,47 @@ TensorProblem::getSolver() const
}
mooseError("No TensorSolver has been set up.");
}

template <typename T>
TensorBuffer<T> &
TensorProblem::getBufferHelper(const std::string & buffer_name)
{
auto it = _tensor_buffer.find(buffer_name);
if (it == _tensor_buffer.end())
mooseError("TensorBuffer '", buffer_name, "' does not exist in the system.");
auto tensor_buffer = dynamic_cast<TensorBuffer<T> *>(it->second.get());
if (!tensor_buffer)
mooseError("TensorBuffer '",
buffer_name,
"' of the requested type '",
it->second->type(),
"' does not exist in the system.");
return *tensor_buffer;
}

template <typename T>
T &
TensorProblem::getBuffer(const std::string & buffer_name)
{
return getBufferHelper<T>(buffer_name)._u;
}

template <typename T>
const std::vector<T> &
TensorProblem::getBufferOld(const std::string & buffer_name, unsigned int max_states)
{
auto & tensor_buffer = getBufferHelper<T>(buffer_name);

if (tensor_buffer._max_states < max_states)
tensor_buffer._max_states = max_states;

return tensor_buffer._u_old;
}

template <typename T>
const T &
TensorProblem::getCPUBuffer(const std::string & buffer_name)
{
_cpu_tensor_buffers.insert(buffer_name);
return getBufferHelper<T>(buffer_name)._u_cpu;
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,27 @@

#pragma once

#include "TensorBufferBase.h"
#ifdef NEML2_ENABLED

#include "TensorBuffer.h"
#include "neml2/tensors/Vec.h"
#include "neml2/tensors/SR2.h"

/**
* Symmetric rank two valued Tensor
* Tensor wrapper arbitrary tensor value dimensions
*/
class SR2TensorBuffer : public TensorBufferBase
template <typename T>
class NEML2TensorBuffer : public TensorBuffer<T>
{
public:
static InputParameters validParams();

SR2TensorBuffer(const InputParameters & parameters);
NEML2TensorBuffer(const InputParameters & parameters);

// NEML2::Scalar getNEML2() { return NEML2::Scalar(*this, _domain_shape, _value_shape); }
virtual void init();
};

using VectorTensor = NEML2TensorBuffer<neml2::Vec>;
using SR2Tensor = NEML2TensorBuffer<neml2::SR2>;

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@

#pragma once

#include "TensorBufferBase.h"
#include "TensorBuffer.h"

/**
* Tensor wrapper arbitrary tensor value dimensions
*/
class ScalarTensorBuffer : public TensorBufferBase
class PlainTensorBuffer : public TensorBuffer<torch::Tensor>
{
public:
static InputParameters validParams();

ScalarTensorBuffer(const InputParameters & parameters);
PlainTensorBuffer(const InputParameters & parameters);

virtual void init();
};
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,27 @@
#include "TensorBufferBase.h"

/**
* Vector valued Tensor
* Tensor wrapper arbitrary tensor value dimensions
*/
class VectorTensorBuffer : public TensorBufferBase
template <typename T>
class TensorBuffer : public TensorBufferBase
{
public:
static InputParameters validParams();

VectorTensorBuffer(const InputParameters & parameters);
TensorBuffer(const InputParameters & parameters);

virtual std::size_t advanceState();
virtual void clearStates();
virtual void makeCPUCopy();

/// current state of the tensor
T _u;

/// potential CPU copy of the tensor (if requested)
T _u_cpu;

/// old states of the tensor
std::vector<T> _u_old;
std::size_t _max_states;
};
28 changes: 16 additions & 12 deletions include/tensor_buffers/TensorBufferBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,31 +16,35 @@
/**
* Tensor wrapper arbitrary tensor value dimensions
*/
class TensorBufferBase : public torch::Tensor, public MooseObject, public DomainInterface
class TensorBufferBase : public MooseObject, public DomainInterface
{
public:
static InputParameters validParams();

TensorBufferBase(const InputParameters & parameters);

/// assignment operator
TensorBufferBase& operator=(const torch::Tensor& rhs);
TensorBufferBase & operator=(const torch::Tensor & rhs);

protected:
const bool _reciprocal;
/// advance state, returns the new number of old states
virtual std::size_t advanceState() = 0;

/// clear old states
virtual void clearStates() = 0;

/// create a contiguous CPU copy of the current tensor
virtual void makeCPUCopy() = 0;

/// initialize the tensor
virtual void init() = 0;

protected:
/// expand the tensor to full dimensions
void expand();

const torch::IntArrayRef _domain_shape;

const std::vector<int64_t> _value_shape_buffer;
const torch::IntArrayRef _value_shape;
const bool _reciprocal;

const std::vector<int64_t> _shape_buffer;
torch::IntArrayRef _shape;
const torch::IntArrayRef _domain_shape;

const torch::TensorOptions _options;

using torch::Tensor::expand;
};
6 changes: 3 additions & 3 deletions src/actions/AddTensorBufferAction.C
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ InputParameters
AddTensorBufferAction::validParams()
{
InputParameters params = MooseObjectAction::validParams();
params.addClassDescription("Add an TensorBuffer object to the simulation.");
params.set<std::string>("type") = "ScalarTensorBuffer";
params.addClassDescription("Add a TensorBuffer object to the simulation.");
params.set<std::string>("type") = "PlainTensorBuffer";
return params;
}

Expand All @@ -32,5 +32,5 @@ AddTensorBufferAction::act()
if (!tensor_problem)
mooseError("Tensor Buffers are only supported if the problem class is set to `TensorProblem`");

tensor_problem->addTensorBuffer(_name, _moose_object_pars);
tensor_problem->addTensorBuffer(_type, _name, _moose_object_pars);
}
Loading

0 comments on commit 0696d3d

Please sign in to comment.