848 FunctionOpInterface callee,
849 const SmallPtrSet<Operation *, 2> &returnOps,
850 SymbolTableCollection *symbolTable,
851 bool verbose,
bool annotate) {
852 auto isActiveData = [&](Value value) {
856 bool backwardActive = bva && bva->getValue().isActiveVal();
857 return forwardActive && backwardActive;
860 auto isConstantValue = [&](Value value) {
862 if (isa<LLVM::LLVMPointerType, MemRefType>(value.getType())) {
863 assert(returnOps.size() == 1);
865 solver.getProgramPointAfter(*returnOps.begin()));
867 solver.getProgramPointBefore(
868 &callee.getFunctionBody().front().front()));
872 solver.getProgramPointAfter(*returnOps.begin()));
873 auto *aliasClassLattice = solver.lookupState<AliasClassLattice>(value);
875 std::deque<DistinctAttr> frontier;
876 DenseSet<DistinctAttr> visited;
878 (void)aliasClasses.foreachElement(
881 "unhandled undefined/unknown case before visit");
882 if (!visited.contains(neighbor)) {
883 visited.insert(neighbor);
884 frontier.push_back(neighbor);
886 return ChangeResult::NoChange;
893 assert(!aliasClassLattice->isUndefined() &&
894 "didn't compute alias classes");
896 if (aliasClassLattice->isUnknown()) {
902 scheduleVisit(aliasClassLattice->getAliasClassesObject());
904 while (!frontier.empty()) {
905 DistinctAttr aliasClass = frontier.front();
906 frontier.pop_front();
910 if (fma->hasActiveData(aliasClass) &&
911 bma->activeDataFlowsOut(aliasClass))
916 assert(!pointsToSets->getPointsTo(aliasClass).isUndefined() &&
917 "couldn't compute points-to sets");
920 if (pointsToSets->getPointsTo(aliasClass).isUnknown())
923 scheduleVisit(pointsToSets->getPointsTo(aliasClass));
929 return !isActiveData(value);
932 std::function<bool(Operation *)> isConstantInstruction = [&](Operation *op) {
936 return llvm::none_of(op->getResults(), isActiveData);
940 if (
auto storeOp = dyn_cast<LLVM::StoreOp>(op)) {
941 if (!isConstantValue(storeOp.getValue())) {
944 }
else if (
auto callOp = dyn_cast<CallOpInterface>(op)) {
947 auto callable = cast<CallableOpInterface>(callOp.resolveCallable());
948 if (callable.getCallableRegion()) {
951 WalkResult result = callable->walk([&](Operation *op) {
952 if (!isConstantInstruction(op)) {
953 return WalkResult::interrupt();
955 return WalkResult::advance();
957 return !result.wasInterrupted();
962 return llvm::none_of(op->getOperands(), isActiveData) &&
963 llvm::none_of(op->getResults(), isActiveData);
966 errs() << FlatSymbolRefAttr::get(callee) <<
":\n";
967 for (BlockArgument arg : callee.getArguments()) {
968 if (Attribute tagAttr =
969 callee.getArgAttr(arg.getArgNumber(),
"enzyme.tag")) {
970 errs() <<
" " << tagAttr <<
": "
971 << (isConstantValue(arg) ?
"Constant" :
"Active") <<
"\n";
976 MLIRContext *ctx = callee.getContext();
978 func.walk([&](Operation *op) {
980 SmallVector<bool> argICVs(func.getNumArguments());
981 llvm::transform(func.getArguments(), argICVs.begin(),
983 func->setAttr(
"enzyme.icv", DenseBoolArrayAttr::get(ctx, argICVs));
987 op->setAttr(
"enzyme.ici",
988 BoolAttr::get(ctx, isConstantInstruction(op)));
991 if (op->getNumResults() == 0) {
993 }
else if (op->getNumResults() == 1) {
994 icv = isConstantValue(op->getResult(0));
997 "annotating icv for op that produces multiple results");
1000 op->setAttr(
"enzyme.icv", BoolAttr::get(ctx, icv));
1004 callee.walk([&](Operation *op) {
1005 if (op->hasAttr(
"tag")) {
1006 errs() <<
" " << op->getAttr(
"tag") <<
": ";
1007 for (OpResult opResult : op->getResults()) {
1008 errs() << (isConstantValue(opResult) ?
"Constant" :
"Active") <<
"\n";
1013 for (OpResult result : op->getResults()) {
1014 auto forwardValueActivity =
1016 if (forwardValueActivity) {
1017 std::string dest, key{
"fva"};
1018 llvm::raw_string_ostream os(dest);
1019 if (op->getNumResults() != 1)
1020 key += result.getResultNumber();
1021 forwardValueActivity->getValue().print(os);
1022 op->setAttr(key, StringAttr::get(op->getContext(), dest));
1025 auto backwardValueActivity =
1027 if (backwardValueActivity) {
1028 std::string dest, key{
"bva"};
1029 llvm::raw_string_ostream os(dest);
1030 if (op->getNumResults() != 1)
1031 key += result.getResultNumber();
1032 backwardValueActivity->getValue().print(os);
1033 op->setAttr(key, StringAttr::get(op->getContext(), dest));
1041 for (BlockArgument arg : callee.getArguments()) {
1042 auto backwardValueActivity =
1044 if (backwardValueActivity) {
1046 llvm::raw_string_ostream os(dest);
1047 backwardValueActivity->getValue().print(os);
1048 callee.setArgAttr(arg.getArgNumber(),
"enzyme.bva",
1049 StringAttr::get(callee->getContext(), dest));
1053 for (Operation *returnOp : returnOps) {
1055 solver.getProgramPointAfter(returnOp));
1057 errs() <<
"forward end state:\n" << *state <<
"\n";
1059 errs() <<
"state was null\n";
1063 solver.getProgramPointAfter(&callee.getFunctionBody().front().front()));
1065 errs() <<
"backwards end state:\n" << *startState <<
"\n";
1067 errs() <<
"backwards end state was null\n";
1072 FunctionOpInterface callee, ArrayRef<enzyme::Activity> argumentActivity,
1073 bool print,
bool verbose,
bool annotate) {
1074 SymbolTableCollection symbolTable;
1075 DataFlowSolver solver;
1087 solver.load<DeadCodeAnalysis>();
1088 solver.load<SparseConstantPropagation>();
1091 for (
const auto &[arg, activity] :
1092 llvm::zip(callee.getArguments(), argumentActivity)) {
1095 if (activity == enzyme::Activity::enzyme_active) {
1103 SmallPtrSet<Operation *, 2> returnOps;
1104 for (Operation &op : callee.getFunctionBody().getOps()) {
1105 if (op.hasTrait<OpTrait::ReturnLike>()) {
1106 returnOps.insert(&op);
1107 for (Value operand : op.getOperands()) {
1108 auto *returnLattice =
1111 if (isa<FloatType, ComplexType>(operand.getType())) {
1118 if (failed(solver.initializeAndRun(callee->getParentOfType<ModuleOp>()))) {
1119 assert(
false &&
"dataflow analysis failed\n");