Skip to content

[flang][cuda] Move cuf.set_allocator_idx after derived-type init #148936

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 15, 2025

Conversation

clementval
Copy link
Contributor

Derived type initialization overwrite the component descriptor. Place the cuf.set_allocator_idx after the initialization is performed.

@clementval clementval requested a review from wangzpgi July 15, 2025 19:21
@llvmbot llvmbot added flang Flang issues not falling into any other category flang:fir-hlfir labels Jul 15, 2025
@llvmbot
Copy link
Member

llvmbot commented Jul 15, 2025

@llvm/pr-subscribers-flang-fir-hlfir

Author: Valentin Clement (バレンタイン クレメン) (clementval)

Changes

Derived type initialization overwrite the component descriptor. Place the cuf.set_allocator_idx after the initialization is performed.


Full diff: https://github.com/llvm/llvm-project/pull/148936.diff

2 Files Affected:

  • (modified) flang/lib/Lower/ConvertVariable.cpp (+94-73)
  • (modified) flang/test/Lower/CUDA/cuda-set-allocator.cuf (+11-3)
diff --git a/flang/lib/Lower/ConvertVariable.cpp b/flang/lib/Lower/ConvertVariable.cpp
index ffe456de56630..bfd4782c764fc 100644
--- a/flang/lib/Lower/ConvertVariable.cpp
+++ b/flang/lib/Lower/ConvertVariable.cpp
@@ -771,79 +771,9 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
       return builder.create<cuf::SharedMemoryOp>(loc, ty, nm, symNm, lenParams,
                                                  indices);
 
-    if (!cuf::isCUDADeviceContext(builder.getRegion())) {
-      mlir::Value alloc = builder.create<cuf::AllocOp>(
-          loc, ty, nm, symNm, dataAttr, lenParams, indices);
-      if (const auto *details{
-              ultimateSymbol
-                  .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
-        const Fortran::semantics::DeclTypeSpec *type{details->type()};
-        const Fortran::semantics::DerivedTypeSpec *derived{
-            type ? type->AsDerived() : nullptr};
-        if (derived) {
-          Fortran::semantics::UltimateComponentIterator components{*derived};
-          auto recTy = mlir::dyn_cast<fir::RecordType>(ty);
-
-          llvm::SmallVector<mlir::Value> coordinates;
-          for (const auto &sym : components) {
-            if (Fortran::semantics::IsDeviceAllocatable(sym)) {
-              unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
-              mlir::Type fieldTy;
-              std::vector<mlir::Value> coordinates;
-
-              if (fieldIdx != std::numeric_limits<unsigned>::max()) {
-                // Field found in the base record type.
-                auto fieldName = recTy.getTypeList()[fieldIdx].first;
-                fieldTy = recTy.getTypeList()[fieldIdx].second;
-                mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
-                    loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
-                    recTy,
-                    /*typeParams=*/mlir::ValueRange{});
-                coordinates.push_back(fieldIndex);
-              } else {
-                // Field not found in base record type, search in potential
-                // record type components.
-                for (auto component : recTy.getTypeList()) {
-                  if (auto childRecTy =
-                          mlir::dyn_cast<fir::RecordType>(component.second)) {
-                    fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
-                    if (fieldIdx != std::numeric_limits<unsigned>::max()) {
-                      mlir::Value parentFieldIndex =
-                          builder.create<fir::FieldIndexOp>(
-                              loc, fir::FieldType::get(childRecTy.getContext()),
-                              component.first, recTy,
-                              /*typeParams=*/mlir::ValueRange{});
-                      coordinates.push_back(parentFieldIndex);
-                      auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
-                      fieldTy = childRecTy.getTypeList()[fieldIdx].second;
-                      mlir::Value childFieldIndex =
-                          builder.create<fir::FieldIndexOp>(
-                              loc, fir::FieldType::get(fieldTy.getContext()),
-                              fieldName, childRecTy,
-                              /*typeParams=*/mlir::ValueRange{});
-                      coordinates.push_back(childFieldIndex);
-                      break;
-                    }
-                  }
-                }
-              }
-
-              if (coordinates.empty())
-                TODO(loc, "device resident component in complex derived-type "
-                          "hierarchy");
-
-              mlir::Value comp = builder.create<fir::CoordinateOp>(
-                  loc, builder.getRefType(fieldTy), alloc, coordinates);
-              cuf::DataAttributeAttr dataAttr =
-                  Fortran::lower::translateSymbolCUFDataAttribute(
-                      builder.getContext(), sym);
-              builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
-            }
-          }
-        }
-      }
-      return alloc;
-    }
+    if (!cuf::isCUDADeviceContext(builder.getRegion()))
+      return builder.create<cuf::AllocOp>(loc, ty, nm, symNm, dataAttr,
+                                          lenParams, indices);
   }
 
   // Let the builder do all the heavy lifting.
@@ -857,6 +787,91 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
   return res;
 }
 
+/// Device allocatable component in a derived-type don't have the correct
+/// allocator index in their descriptor when they are created. After
+/// initialization, cuf.set_allocator_idx operations are inserted to set the
+/// correct allocator index for each device component.
+static void
+initializeDeviceComponentAllocator(Fortran::lower::AbstractConverter &converter,
+                                   const Fortran::semantics::Symbol &symbol,
+                                   Fortran::lower::SymMap &symMap) {
+  if (const auto *details{
+          symbol.GetUltimate()
+              .detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
+    const Fortran::semantics::DeclTypeSpec *type{details->type()};
+    const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
+                                                            : nullptr};
+    if (derived) {
+      fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+      mlir::Location loc = converter.getCurrentLocation();
+
+      fir::ExtendedValue exv =
+          converter.getSymbolExtendedValue(symbol.GetUltimate(), &symMap);
+      auto recTy = mlir::dyn_cast<fir::RecordType>(
+          fir::unwrapRefType(fir::getBase(exv).getType()));
+      assert(recTy && "expected fir::RecordType");
+
+      llvm::SmallVector<mlir::Value> coordinates;
+      Fortran::semantics::UltimateComponentIterator components{*derived};
+      for (const auto &sym : components) {
+        if (Fortran::semantics::IsDeviceAllocatable(sym)) {
+          unsigned fieldIdx = recTy.getFieldIndex(sym.name().ToString());
+          mlir::Type fieldTy;
+          std::vector<mlir::Value> coordinates;
+
+          if (fieldIdx != std::numeric_limits<unsigned>::max()) {
+            // Field found in the base record type.
+            auto fieldName = recTy.getTypeList()[fieldIdx].first;
+            fieldTy = recTy.getTypeList()[fieldIdx].second;
+            mlir::Value fieldIndex = builder.create<fir::FieldIndexOp>(
+                loc, fir::FieldType::get(fieldTy.getContext()), fieldName,
+                recTy,
+                /*typeParams=*/mlir::ValueRange{});
+            coordinates.push_back(fieldIndex);
+          } else {
+            // Field not found in base record type, search in potential
+            // record type components.
+            for (auto component : recTy.getTypeList()) {
+              if (auto childRecTy =
+                      mlir::dyn_cast<fir::RecordType>(component.second)) {
+                fieldIdx = childRecTy.getFieldIndex(sym.name().ToString());
+                if (fieldIdx != std::numeric_limits<unsigned>::max()) {
+                  mlir::Value parentFieldIndex =
+                      builder.create<fir::FieldIndexOp>(
+                          loc, fir::FieldType::get(childRecTy.getContext()),
+                          component.first, recTy,
+                          /*typeParams=*/mlir::ValueRange{});
+                  coordinates.push_back(parentFieldIndex);
+                  auto fieldName = childRecTy.getTypeList()[fieldIdx].first;
+                  fieldTy = childRecTy.getTypeList()[fieldIdx].second;
+                  mlir::Value childFieldIndex =
+                      builder.create<fir::FieldIndexOp>(
+                          loc, fir::FieldType::get(fieldTy.getContext()),
+                          fieldName, childRecTy,
+                          /*typeParams=*/mlir::ValueRange{});
+                  coordinates.push_back(childFieldIndex);
+                  break;
+                }
+              }
+            }
+          }
+
+          if (coordinates.empty())
+            TODO(loc, "device resident component in complex derived-type "
+                      "hierarchy");
+
+          mlir::Value comp = builder.create<fir::CoordinateOp>(
+              loc, builder.getRefType(fieldTy), fir::getBase(exv), coordinates);
+          cuf::DataAttributeAttr dataAttr =
+              Fortran::lower::translateSymbolCUFDataAttribute(
+                  builder.getContext(), sym);
+          builder.create<cuf::SetAllocatorIndexOp>(loc, comp, dataAttr);
+        }
+      }
+    }
+  }
+}
+
 /// Must \p var be default initialized at runtime when entering its scope.
 static bool
 mustBeDefaultInitializedAtRuntime(const Fortran::lower::pft::Variable &var) {
@@ -1179,6 +1194,9 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
   if (mustBeDefaultInitializedAtRuntime(var))
     Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
                                                symMap);
+  if (converter.getFoldingContext().languageFeatures().IsEnabled(
+          Fortran::common::LanguageFeature::CUDA))
+    initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
   auto *builder = &converter.getFirOpBuilder();
   if (needCUDAAlloc(var.getSymbol()) &&
       !cuf::isCUDADeviceContext(builder->getRegion())) {
@@ -1437,6 +1455,9 @@ static void instantiateAlias(Fortran::lower::AbstractConverter &converter,
   if (mustBeDefaultInitializedAtRuntime(var))
     Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
                                                symMap);
+  if (converter.getFoldingContext().languageFeatures().IsEnabled(
+          Fortran::common::LanguageFeature::CUDA))
+    initializeDeviceComponentAllocator(converter, var.getSymbol(), symMap);
 }
 
 //===--------------------------------------------------------------===//
diff --git a/flang/test/Lower/CUDA/cuda-set-allocator.cuf b/flang/test/Lower/CUDA/cuda-set-allocator.cuf
index bf74e012a639d..85715edf26942 100644
--- a/flang/test/Lower/CUDA/cuda-set-allocator.cuf
+++ b/flang/test/Lower/CUDA/cuda-set-allocator.cuf
@@ -12,10 +12,18 @@ contains
   end subroutine
 
 ! CHECK-LABEL: func.func @_QMm1Psub1()
-! CHECK: %[[DT:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
-! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]], x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
+! CHECK: %[[DT:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>, !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>)
+! CHECK: fir.address_of(@_QQ_QMm1Tty_device.DerivedInit)
+! CHECK: fir.copy 
+! CHECK: %[[X:.*]] = fir.coordinate_of %[[DT]]#0, x : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
 ! CHECK: cuf.set_allocator_idx %[[X]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
-! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]], z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
+! CHECK: %[[Z:.*]] = fir.coordinate_of %[[DT]]#0, z : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>,y:i32,z:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>>
 ! CHECK: cuf.set_allocator_idx %[[Z]] : !fir.ref<!fir.box<!fir.heap<!fir.array<?xi32>>>> {data_attr = #cuf.cuda<device>}
 
 end module
+
+
+
+    
+  

@clementval clementval merged commit b4e2272 into llvm:main Jul 15, 2025
7 of 8 checks passed
@clementval clementval deleted the cuf_set_alloc_after_init branch July 15, 2025 20:52
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
flang:fir-hlfir flang Flang issues not falling into any other category
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants