Enzyme main
HMCUtils.cpp File Reference
#include "HMCUtils.h"
#include "Dialect/Impulse/Impulse.h"
#include "Dialect/Ops.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Math/IR/Math.h"
#include <cmath>
#include <limits>


Functions

static Value createIdentityMatrix (OpBuilder &builder, Location loc, RankedTensorType matrixType)
 Creates a 2D identity matrix of the specified type.
 
static Value createPermutationMatrix (OpBuilder &builder, Location loc, RankedTensorType matrixType)
 Creates an n x n permutation matrix (the exchange matrix, with ones on the anti-diagonal).
 
static Value reverseRowsAndColumns (OpBuilder &builder, Location loc, Value matrix)
 Computes A[::-1, ::-1] (A with both its rows and its columns reversed) as P @ A @ P, where P is the permutation matrix from createPermutationMatrix().
 
static Value scatterPositionToTrace (OpBuilder &builder, Location loc, Value position2d, Value fullTrace, const HMCContext &ctx)
 
static Value gatherPositionFromTrace (OpBuilder &builder, Location loc, Value fullTrace, const HMCContext &ctx)
 

Function Documentation

◆ createIdentityMatrix()

static Value createIdentityMatrix(OpBuilder &builder, Location loc, RankedTensorType matrixType)

Creates a 2D identity matrix of the specified type.

Definition at line 75 of file HMCUtils.cpp.

Referenced by mlir::impulse::computeMassMatrixSqrt(), and mlir::impulse::finalizeWelford().
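The real helper emits MLIR operations through `builder` at `loc` and returns the resulting `Value`; this page does not show which ops it uses. As a rough illustration of the numeric result only, the plain C++ sketch below builds the same matrix eagerly. `Matrix` and `makeIdentity` are hypothetical names, and the restriction to square shapes is an assumption of the sketch, not something the source states.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical plain-C++ stand-in for a rank-2 tensor (dense, row-major).
using Matrix = std::vector<std::vector<double>>;

// Sketch of what createIdentityMatrix() yields for a [rows, cols] tensor type:
// ones on the main diagonal, zeros elsewhere. This sketch assumes a square
// shape and checks that assumption with an assert.
Matrix makeIdentity(std::size_t rows, std::size_t cols) {
  assert(rows == cols && "sketch assumes a square shape");
  Matrix eye(rows, std::vector<double>(cols, 0.0));
  for (std::size_t i = 0; i < rows; ++i)
    eye[i][i] = 1.0;
  return eye;
}
```

For example, `makeIdentity(3, 3)` returns a 3 x 3 matrix whose entry `[i][j]` is 1.0 when `i == j` and 0.0 otherwise.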

◆ createPermutationMatrix()

static Value createPermutationMatrix(OpBuilder &builder, Location loc, RankedTensorType matrixType)

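Creates an n x n permutation matrix. Because reverseRowsAndColumns() uses it to compute A[::-1, ::-1] as P @ A @ P, it is the exchange matrix, with ones on the anti-diagonal and zeros elsewhere.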

Definition at line 98 of file HMCUtils.cpp.

Referenced by reverseRowsAndColumns().
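The page states only that the result is an n x n permutation matrix. Since reverseRowsAndColumns() uses it to reverse both rows and columns via P @ A @ P, the sketch below assumes it is the exchange (anti-identity) matrix. It is plain C++ that computes the matrix eagerly rather than emitting MLIR ops, and `Matrix`, `makeExchangeMatrix`, and `isPermutationMatrix` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense, row-major stand-in for a rank-2 tensor.
using Matrix = std::vector<std::vector<double>>;

// Assumed result of createPermutationMatrix(): the n x n exchange matrix,
// with p[i][n-1-i] = 1 and every other entry 0.
Matrix makeExchangeMatrix(std::size_t n) {
  Matrix p(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i)
    p[i][n - 1 - i] = 1.0;
  return p;
}

// Checks the defining property of a permutation matrix: every entry is 0 or 1,
// and each row and each column contains exactly one 1.
bool isPermutationMatrix(const Matrix &p) {
  const std::size_t n = p.size();
  std::vector<int> colCount(n, 0);
  for (const auto &row : p) {
    assert(row.size() == n && "expected a square matrix");
    int rowCount = 0;
    for (std::size_t j = 0; j < n; ++j) {
      if (row[j] == 1.0) {
        ++rowCount;
        ++colCount[j];
      } else if (row[j] != 0.0) {
        return false;
      }
    }
    if (rowCount != 1)
      return false;
  }
  for (int c : colCount)
    if (c != 1)
      return false;
  return true;
}
```

The exchange matrix is symmetric and its own inverse (P @ P = I), which is why the same P can appear on both sides of P @ A @ P.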

◆ gatherPositionFromTrace()

static Value gatherPositionFromTrace(OpBuilder &builder, Location loc, Value fullTrace, const HMCContext &ctx)

◆ reverseRowsAndColumns()

static Value reverseRowsAndColumns(OpBuilder &builder, Location loc, Value matrix)

Computes A[::-1, ::-1] (A with both its rows and its columns reversed) as P @ A @ P, where P is the permutation matrix from createPermutationMatrix().

Definition at line 121 of file HMCUtils.cpp.

References createPermutationMatrix().

Referenced by mlir::impulse::computeMassMatrixSqrt().
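The identity behind this helper can be checked numerically: with P the exchange matrix (ones on the anti-diagonal), left-multiplying by P reverses the row order and right-multiplying by P reverses the column order, so (P @ A @ P)[i][j] = A[n-1-i][n-1-j]. The plain C++ sketch below evaluates this eagerly on a dense matrix; the real function instead builds equivalent MLIR ops through `builder`. `Matrix`, `makeExchangeMatrix`, `matmul`, and `reverseRowsAndColumnsSketch` are hypothetical names, and P being the exchange matrix is an assumption consistent with the description above.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense, row-major stand-in for a rank-2 tensor.
using Matrix = std::vector<std::vector<double>>;

// Assumed permutation matrix: n x n exchange matrix, p[i][n-1-i] = 1.
Matrix makeExchangeMatrix(std::size_t n) {
  Matrix p(n, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < n; ++i)
    p[i][n - 1 - i] = 1.0;
  return p;
}

// Naive dense product of an m x k matrix and a k x n matrix.
Matrix matmul(const Matrix &a, const Matrix &b) {
  assert(!a.empty() && !b.empty() && a[0].size() == b.size());
  const std::size_t m = a.size(), k = b.size(), n = b[0].size();
  Matrix c(m, std::vector<double>(n, 0.0));
  for (std::size_t i = 0; i < m; ++i)
    for (std::size_t t = 0; t < k; ++t)
      for (std::size_t j = 0; j < n; ++j)
        c[i][j] += a[i][t] * b[t][j];
  return c;
}

// A[::-1, ::-1] as P * A * P: P * A reverses the rows, and multiplying that
// result by P on the right reverses the columns.
Matrix reverseRowsAndColumnsSketch(const Matrix &a) {
  const std::size_t n = a.size();
  assert(n > 0 && a[0].size() == n && "sketch assumes a non-empty square matrix");
  const Matrix p = makeExchangeMatrix(n);
  return matmul(matmul(p, a), p);
}
```

For the 3 x 3 matrix with rows {1, 2, 3}, {4, 5, 6}, {7, 8, 9}, the sketch returns rows {9, 8, 7}, {6, 5, 4}, {3, 2, 1}. Computed this way the reversal costs two dense n x n products (O(n^3)), compared with O(n^2) for reading a[n-1-i][n-1-j] directly.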

◆ scatterPositionToTrace()

static Value scatterPositionToTrace(OpBuilder &builder, Location loc, Value position2d, Value fullTrace, const HMCContext &ctx)