690 PatternRewriter &rewriter)
const override {
692 if (uop.getOutputs().size() == 0)
695 auto inpActivity = uop.getActivity();
696 auto retActivity = uop.getRetActivity();
698 SmallVector<mlir::Value, 2> in_args;
699 SmallVector<mlir::Value, 2> outs_args;
700 SmallVector<Type, 2> in_ty;
701 SmallVector<Type, 2> out_ty;
702 SmallVector<ActivityAttr, 2> newInActivityArgs;
703 SmallVector<ActivityAttr, 2> newRetActivityArgs;
705 bool changed =
false;
709 for (
auto [idx, act] : llvm::enumerate(inpActivity)) {
710 auto iattr = cast<ActivityAttr>(act);
711 auto val = iattr.getValue();
712 mlir::Value res = uop.getInputs()[in_idx];
713 in_args.push_back(res);
714 in_ty.push_back(res.getType());
717 if (val == Activity::enzyme_dup || val == Activity::enzyme_dupnoneed) {
718 mlir::Value dres = uop.getInputs()[in_idx];
719 in_args.push_back(dres);
720 in_ty.push_back(dres.getType());
725 if (in_idx == uop.getInputs().size())
729 for (
auto [idx, act] : llvm::enumerate(retActivity)) {
730 auto iattr = cast<ActivityAttr>(act);
731 auto val = iattr.getValue();
734 if (val == Activity::enzyme_constnoneed ||
735 val == Activity::enzyme_dupnoneed) {
736 newRetActivityArgs.push_back(iattr);
740 mlir::Value res = uop.getOutputs()[out_idx];
743 case Activity::enzyme_active: {
748 mlir::Value dres = uop.getInputs()[in_idx];
751 auto dres_type = dres.getType();
752 auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
754 if (!res.use_empty()) {
755 outs_args.push_back(res);
756 out_ty.push_back(res.getType());
757 ActivityAttr new_act = iattr;
758 if (dres_type_intf && !
isMutable(dres_type) &&
759 dres_type_intf.isZero(dres)) {
762 new_act = ActivityAttr::get(rewriter.getContext(),
763 Activity::enzyme_const);
765 in_args.push_back(dres);
766 in_ty.push_back(dres_type);
768 newRetActivityArgs.push_back(new_act);
771 ActivityAttr new_act = ActivityAttr::get(
772 rewriter.getContext(), Activity::enzyme_activenoneed);
773 if (dres_type_intf && !
isMutable(dres_type) &&
774 dres_type_intf.isZero(dres)) {
776 new_act = ActivityAttr::get(rewriter.getContext(),
777 Activity::enzyme_constnoneed);
780 in_args.push_back(dres);
781 in_ty.push_back(dres_type);
783 newRetActivityArgs.push_back(new_act);
790 case Activity::enzyme_activenoneed:
793 mlir::Value dres = uop.getInputs()[in_idx];
795 auto new_act = iattr;
797 auto dres_type = dres.getType();
798 auto dres_type_intf = dyn_cast<AutoDiffTypeInterface>(dres_type);
799 if (dres_type_intf && !
isMutable(dres_type) &&
800 dres_type_intf.isZero(dres)) {
802 new_act = ActivityAttr::get(rewriter.getContext(),
803 Activity::enzyme_constnoneed);
805 in_args.push_back(dres);
806 in_ty.push_back(dres_type);
808 newRetActivityArgs.push_back(iattr);
811 case Activity::enzyme_const:
814 auto new_act = iattr;
815 if (!res.use_empty()) {
816 outs_args.push_back(res);
817 out_ty.push_back(res.getType());
818 newRetActivityArgs.push_back(new_act);
821 new_act = ActivityAttr::get(rewriter.getContext(),
822 Activity::enzyme_constnoneed);
823 newRetActivityArgs.push_back(new_act);
829 case Activity::enzyme_dup:
832 outs_args.push_back(res);
833 out_ty.push_back(res.getType());
834 newRetActivityArgs.push_back(iattr);
838 case Activity::enzyme_constnoneed:
839 case Activity::enzyme_dupnoneed:
843 llvm_unreachable(
"unexpected activity arg");
848 for (
auto [idx, act] : llvm::enumerate(inpActivity)) {
849 auto iattr = cast<ActivityAttr>(act);
850 auto val = iattr.getValue();
852 if (val == Activity::enzyme_active) {
853 mlir::Value res = uop.getOutputs()[out_idx];
854 if (!res.use_empty()) {
855 out_ty.push_back(res.getType());
856 outs_args.push_back(res);
857 newInActivityArgs.push_back(iattr);
862 auto new_const = ActivityAttr::get(rewriter.getContext(),
863 Activity::enzyme_const);
864 newInActivityArgs.push_back(new_const);
867 out_ty.push_back(res.getType());
868 outs_args.push_back(res);
869 newInActivityArgs.push_back(iattr);
874 }
else if (val == Activity::enzyme_activenoneed) {
875 mlir::Value res = uop.getOutputs()[out_idx];
876 out_ty.push_back(res.getType());
877 outs_args.push_back(res);
878 newInActivityArgs.push_back(iattr);
880 llvm_unreachable(
"unsupported arg activenoneed");
882 newInActivityArgs.push_back(iattr);
889 ArrayAttr newInActivity =
890 ArrayAttr::get(rewriter.getContext(),
891 llvm::ArrayRef<Attribute>(newInActivityArgs.begin(),
892 newInActivityArgs.end()));
893 ArrayAttr newRetActivity =
894 ArrayAttr::get(rewriter.getContext(),
895 llvm::ArrayRef<Attribute>(newRetActivityArgs.begin(),
896 newRetActivityArgs.end()));
898 SourceOp newOp = SourceOpCreator::create(rewriter, uop, out_ty, in_args,
899 newInActivity, newRetActivity);
904 for (
auto [idx, old_act, new_act] :
905 llvm::enumerate(retActivity, newRetActivityArgs)) {
906 auto iattr = cast<ActivityAttr>(old_act);
907 auto old_val = iattr.getValue();
908 auto new_val = new_act.getValue();
910 if (old_val == new_val) {
912 if (old_val == Activity::enzyme_constnoneed ||
913 old_val == Activity::enzyme_activenoneed ||
914 old_val == Activity::enzyme_dupnoneed) {
918 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
919 newOp.getOutputs()[newIdx++]);
922 if (new_val == Activity::enzyme_activenoneed &&
923 old_val == Activity::enzyme_active) {
925 }
else if (new_val == Activity::enzyme_constnoneed &&
926 old_val == Activity::enzyme_const) {
928 }
else if (old_val == Activity::enzyme_active &&
929 new_val == Activity::enzyme_const) {
930 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
931 newOp.getOutputs()[newIdx++]);
932 }
else if (old_val == Activity::enzyme_active &&
933 new_val == Activity::enzyme_constnoneed) {
935 }
else if (old_val == Activity::enzyme_activenoneed &&
936 new_val == Activity::enzyme_constnoneed) {
942 for (
auto [idx, old_act, new_act] :
943 llvm::enumerate(inpActivity, newInActivityArgs)) {
944 auto iattr = cast<ActivityAttr>(old_act);
945 auto old_val = iattr.getValue();
946 auto new_val = new_act.getValue();
948 if (old_val == new_val) {
949 if (old_val == Activity::enzyme_active ||
950 old_val == Activity::enzyme_activenoneed) {
951 uop.getOutputs()[oldIdx++].replaceAllUsesWith(
952 newOp.getOutputs()[newIdx++]);
957 if (old_val == Activity::enzyme_active &&
958 new_val == Activity::enzyme_const) {
963 rewriter.eraseOp(uop);