Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SPV_AMDX_shader_enqueue version 2 support #5838

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions DEPS
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ vars = {

're2_revision': '6dcd83d60f7944926bfd308cc13979fc53dd69ca',

'spirv_headers_revision': 'd92cf88c371424591115a87499009dfad41b669c',
'spirv_headers_revision': '07ddb1c0f1ffa929262d4568481a692bb0fb1535',
}

deps = {
Expand All @@ -37,4 +37,3 @@ deps = {
Var('github') + '/KhronosGroup/SPIRV-Headers.git@' +
Var('spirv_headers_revision'),
}

6 changes: 6 additions & 0 deletions source/name_mapper.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2016 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -241,6 +243,10 @@ spv_result_t FriendlyNameMapper::ParseInstruction(
SaveName(result_id,
std::string("_runtimearr_") + NameForId(inst.words[2]));
break;
case spv::Op::OpTypeNodePayloadArrayAMDX:
SaveName(result_id,
std::string("_payloadarr_") + NameForId(inst.words[2]));
break;
case spv::Op::OpTypePointer:
SaveName(result_id, std::string("_ptr_") +
NameForEnumOperand(SPV_OPERAND_TYPE_STORAGE_CLASS,
Expand Down
9 changes: 7 additions & 2 deletions source/opcode.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// Copyright (c) 2015-2022 The Khronos Group Inc.
// Modifications Copyright (C) 2020 Advanced Micro Devices, Inc. All rights
// reserved.
// Modifications Copyright (C) 2020-2024 Advanced Micro Devices, Inc. All
// rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -265,12 +265,14 @@ int32_t spvOpcodeIsConstant(const spv::Op opcode) {
case spv::Op::OpConstantSampler:
case spv::Op::OpConstantNull:
case spv::Op::OpConstantFunctionPointerINTEL:
case spv::Op::OpConstantStringAMDX:
case spv::Op::OpSpecConstantTrue:
case spv::Op::OpSpecConstantFalse:
case spv::Op::OpSpecConstant:
case spv::Op::OpSpecConstantComposite:
case spv::Op::OpSpecConstantCompositeReplicateEXT:
case spv::Op::OpSpecConstantOp:
case spv::Op::OpSpecConstantStringAMDX:
return true;
default:
return false;
Expand Down Expand Up @@ -318,6 +320,7 @@ bool spvOpcodeReturnsLogicalVariablePointer(const spv::Op opcode) {
case spv::Op::OpFunctionParameter:
case spv::Op::OpImageTexelPointer:
case spv::Op::OpCopyObject:
case spv::Op::OpAllocateNodePayloadsAMDX:
case spv::Op::OpSelect:
case spv::Op::OpPhi:
case spv::Op::OpFunctionCall:
Expand All @@ -344,6 +347,7 @@ int32_t spvOpcodeReturnsLogicalPointer(const spv::Op opcode) {
case spv::Op::OpImageTexelPointer:
case spv::Op::OpCopyObject:
case spv::Op::OpRawAccessChainNV:
case spv::Op::OpAllocateNodePayloadsAMDX:
return true;
default:
return false;
Expand Down Expand Up @@ -382,6 +386,7 @@ int32_t spvOpcodeGeneratesType(spv::Op op) {
case spv::Op::OpTypeRayQueryKHR:
case spv::Op::OpTypeHitObjectNV:
case spv::Op::OpTypeUntypedPointerKHR:
case spv::Op::OpTypeNodePayloadArrayAMDX:
return true;
default:
// In particular, OpTypeForwardPointer does not generate a type,
Expand Down
4 changes: 4 additions & 0 deletions source/opt/fix_storage_class.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2019 Google LLC
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -99,6 +101,7 @@ bool FixStorageClass::PropagateStorageClass(Instruction* inst,
case spv::Op::OpCopyMemorySized:
case spv::Op::OpVariable:
case spv::Op::OpBitcast:
case spv::Op::OpAllocateNodePayloadsAMDX:
// Nothing to change for these opcode. The result type is the same
// regardless of the storage class of the operand.
return false;
Expand Down Expand Up @@ -319,6 +322,7 @@ uint32_t FixStorageClass::WalkAccessChainType(Instruction* inst, uint32_t id) {
switch (type_inst->opcode()) {
case spv::Op::OpTypeArray:
case spv::Op::OpTypeRuntimeArray:
case spv::Op::OpTypeNodePayloadArrayAMDX:
case spv::Op::OpTypeMatrix:
case spv::Op::OpTypeVector:
case spv::Op::OpTypeCooperativeMatrixKHR:
Expand Down
3 changes: 3 additions & 0 deletions source/opt/ir_context.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2017 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -539,6 +541,7 @@ void IRContext::AddCombinatorsForCapability(uint32_t capability) {
(uint32_t)spv::Op::OpTypeHitObjectNV,
(uint32_t)spv::Op::OpTypeArray,
(uint32_t)spv::Op::OpTypeRuntimeArray,
(uint32_t)spv::Op::OpTypeNodePayloadArrayAMDX,
(uint32_t)spv::Op::OpTypeStruct,
(uint32_t)spv::Op::OpTypeOpaque,
(uint32_t)spv::Op::OpTypePointer,
Expand Down
5 changes: 4 additions & 1 deletion source/opt/local_access_chain_convert_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) 2017 The Khronos Group Inc.
// Copyright (c) 2017 Valve Corporation
// Copyright (c) 2017 LunarG Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -430,7 +432,8 @@ void LocalAccessChainConvertPass::InitExtensions() {
"SPV_NV_bindless_texture", "SPV_EXT_shader_atomic_float_add",
"SPV_EXT_fragment_shader_interlock",
"SPV_KHR_compute_shader_derivatives", "SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch"});
"SPV_KHR_cooperative_matrix", "SPV_KHR_ray_tracing_position_fetch",
"SPV_AMDX_shader_enqueue"});
}

bool LocalAccessChainConvertPass::AnyIndexIsOutOfBounds(
Expand Down
3 changes: 3 additions & 0 deletions source/opt/local_single_block_elim_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) 2017 The Khronos Group Inc.
// Copyright (c) 2017 Valve Corporation
// Copyright (c) 2017 LunarG Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -238,6 +240,7 @@ void LocalSingleBlockLoadStoreElimPass::InitExtensions() {
"SPV_AMD_gcn_shader",
"SPV_KHR_shader_ballot",
"SPV_AMD_shader_ballot",
"SPV_AMDX_shader_enqueue",
"SPV_AMD_gpu_shader_half_float",
"SPV_KHR_shader_draw_parameters",
"SPV_KHR_subgroup_vote",
Expand Down
5 changes: 4 additions & 1 deletion source/opt/local_single_store_elim_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
// Copyright (c) 2017 The Khronos Group Inc.
// Copyright (c) 2017 Valve Corporation
// Copyright (c) 2017 LunarG Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -144,7 +146,8 @@ void LocalSingleStoreElimPass::InitExtensionAllowList() {
"SPV_KHR_compute_shader_derivatives",
"SPV_NV_cooperative_matrix",
"SPV_KHR_cooperative_matrix",
"SPV_KHR_ray_tracing_position_fetch"});
"SPV_KHR_ray_tracing_position_fetch",
"SPV_AMDX_shader_enqueue"});
}
bool LocalSingleStoreElimPass::ProcessVariable(Instruction* var_inst) {
std::vector<Instruction*> users;
Expand Down
5 changes: 4 additions & 1 deletion source/opt/scalar_replacement_pass.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2017 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -671,7 +673,8 @@ bool ScalarReplacementPass::CheckTypeAnnotations(
for (auto inst :
get_decoration_mgr()->GetDecorationsFor(typeInst->result_id(), false)) {
uint32_t decoration;
if (inst->opcode() == spv::Op::OpDecorate) {
if (inst->opcode() == spv::Op::OpDecorate ||
inst->opcode() == spv::Op::OpDecorateId) {
decoration = inst->GetSingleWordInOperand(1u);
} else {
assert(inst->opcode() == spv::Op::OpMemberDecorate);
Expand Down
31 changes: 30 additions & 1 deletion source/opt/type_manager.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2016 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -335,6 +337,17 @@ uint32_t TypeManager::GetTypeInstruction(const Type* type) {
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {subtype}}});
break;
}
case Type::kNodePayloadArrayAMDX: {
uint32_t subtype =
GetTypeInstruction(type->AsNodePayloadArrayAMDX()->element_type());
if (subtype == 0) {
return 0;
}
typeInst = MakeUnique<Instruction>(
context(), spv::Op::OpTypeNodePayloadArrayAMDX, 0, id,
std::initializer_list<Operand>{{SPV_OPERAND_TYPE_ID, {subtype}}});
break;
}
case Type::kStruct: {
std::vector<Operand> ops;
const Struct* structTy = type->AsStruct();
Expand Down Expand Up @@ -601,6 +614,13 @@ Type* TypeManager::RebuildType(uint32_t type_id, const Type& type) {
MakeUnique<RuntimeArray>(RebuildType(GetId(ele_ty), *ele_ty));
break;
}
case Type::kNodePayloadArrayAMDX: {
const NodePayloadArrayAMDX* array_ty = type.AsNodePayloadArrayAMDX();
const Type* ele_ty = array_ty->element_type();
rebuilt_ty =
MakeUnique<NodePayloadArrayAMDX>(RebuildType(GetId(ele_ty), *ele_ty));
break;
}
case Type::kStruct: {
const Struct* struct_ty = type.AsStruct();
std::vector<const Type*> subtypes;
Expand Down Expand Up @@ -803,6 +823,14 @@ Type* TypeManager::RecordIfTypeDefinition(const Instruction& inst) {
return type;
}
break;
case spv::Op::OpTypeNodePayloadArrayAMDX:
type = new NodePayloadArrayAMDX(GetType(inst.GetSingleWordInOperand(0)));
if (id_to_incomplete_type_.count(inst.GetSingleWordInOperand(0))) {
incomplete_types_.emplace_back(inst.result_id(), type);
id_to_incomplete_type_[inst.result_id()] = type;
return type;
}
break;
case spv::Op::OpTypeStruct: {
std::vector<const Type*> element_types;
bool incomplete_type = false;
Expand Down Expand Up @@ -940,7 +968,8 @@ void TypeManager::AttachDecoration(const Instruction& inst, Type* type) {
if (!IsAnnotationInst(opcode)) return;

switch (opcode) {
case spv::Op::OpDecorate: {
case spv::Op::OpDecorate:
case spv::Op::OpDecorateId: {
const auto count = inst.NumOperands();
std::vector<uint32_t> data;
for (uint32_t i = 1; i < count; ++i) {
Expand Down
33 changes: 33 additions & 0 deletions source/opt/types.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2016 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -90,6 +92,7 @@ bool Type::IsUniqueType() const {
case kStruct:
case kArray:
case kRuntimeArray:
case kNodePayloadArrayAMDX:
return false;
default:
return true;
Expand Down Expand Up @@ -162,6 +165,7 @@ bool Type::operator==(const Type& other) const {
DeclareKindCase(SampledImage);
DeclareKindCase(Array);
DeclareKindCase(RuntimeArray);
DeclareKindCase(NodePayloadArrayAMDX);
DeclareKindCase(Struct);
DeclareKindCase(Opaque);
DeclareKindCase(Pointer);
Expand Down Expand Up @@ -218,6 +222,7 @@ size_t Type::ComputeHashValue(size_t hash, SeenTypes* seen) const {
DeclareKindCase(SampledImage);
DeclareKindCase(Array);
DeclareKindCase(RuntimeArray);
DeclareKindCase(NodePayloadArrayAMDX);
DeclareKindCase(Struct);
DeclareKindCase(Opaque);
DeclareKindCase(Pointer);
Expand Down Expand Up @@ -485,6 +490,34 @@ void RuntimeArray::ReplaceElementType(const Type* type) {
element_type_ = type;
}

NodePayloadArrayAMDX::NodePayloadArrayAMDX(const Type* type)
: Type(kNodePayloadArrayAMDX), element_type_(type) {
assert(!type->AsVoid());
}

bool NodePayloadArrayAMDX::IsSameImpl(const Type* that,
IsSameCache* seen) const {
const NodePayloadArrayAMDX* rat = that->AsNodePayloadArrayAMDX();
if (!rat) return false;
return element_type_->IsSameImpl(rat->element_type_, seen) &&
HasSameDecorations(that);
}

std::string NodePayloadArrayAMDX::str() const {
std::ostringstream oss;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a specific reason to use a ossringstream object? Given it's a basic routine which might be frequently called I suppose a str concat should be equivalent and more efficient. More so for smaller string values.
If we do have a need to use the ostringstream class obj perhaps consider reusing it by making it part of the class and clear/reset per use.

oss << "[" << element_type_->str() << "]";

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With larger sized strings perhaps using string_view (C++17) is better if a basic str concat is not an option.

return oss.str();
}

size_t NodePayloadArrayAMDX::ComputeExtraStateHash(size_t hash,
SeenTypes* seen) const {
return element_type_->ComputeHashValue(hash, seen);
}

void NodePayloadArrayAMDX::ReplaceElementType(const Type* type) {
element_type_ = type;
}

Struct::Struct(const std::vector<const Type*>& types)
: Type(kStruct), element_types_(types) {
for (const auto* t : types) {
Expand Down
28 changes: 28 additions & 0 deletions source/opt/types.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2016 Google Inc.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -46,6 +48,7 @@ class Sampler;
class SampledImage;
class Array;
class RuntimeArray;
class NodePayloadArrayAMDX;
class Struct;
class Opaque;
class Pointer;
Expand Down Expand Up @@ -87,6 +90,7 @@ class Type {
kSampledImage,
kArray,
kRuntimeArray,
kNodePayloadArrayAMDX,
kStruct,
kOpaque,
kPointer,
Expand Down Expand Up @@ -189,6 +193,7 @@ class Type {
DeclareCastMethod(SampledImage)
DeclareCastMethod(Array)
DeclareCastMethod(RuntimeArray)
DeclareCastMethod(NodePayloadArrayAMDX)
DeclareCastMethod(Struct)
DeclareCastMethod(Opaque)
DeclareCastMethod(Pointer)
Expand Down Expand Up @@ -434,6 +439,29 @@ class RuntimeArray : public Type {
const Type* element_type_;
};

class NodePayloadArrayAMDX : public Type {
public:
NodePayloadArrayAMDX(const Type* element_type);
NodePayloadArrayAMDX(const NodePayloadArrayAMDX&) = default;

std::string str() const override;
const Type* element_type() const { return element_type_; }

NodePayloadArrayAMDX* AsNodePayloadArrayAMDX() override { return this; }
const NodePayloadArrayAMDX* AsNodePayloadArrayAMDX() const override {
return this;
}

size_t ComputeExtraStateHash(size_t hash, SeenTypes* seen) const override;

void ReplaceElementType(const Type* element_type);

private:
bool IsSameImpl(const Type* that, IsSameCache*) const override;

const Type* element_type_;
};

class Struct : public Type {
public:
Struct(const std::vector<const Type*>& element_types);
Expand Down
7 changes: 7 additions & 0 deletions source/val/validate_annotation.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
// Copyright (c) 2018 Google LLC.
// Modifications Copyright (C) 2024 Advanced Micro Devices, Inc. All rights
// reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -30,6 +32,11 @@ bool DecorationTakesIdParameters(spv::Decoration type) {
case spv::Decoration::AlignmentId:
case spv::Decoration::MaxByteOffsetId:
case spv::Decoration::HlslCounterBufferGOOGLE:
case spv::Decoration::NodeMaxPayloadsAMDX:
case spv::Decoration::NodeSharesPayloadLimitsWithAMDX:
case spv::Decoration::PayloadNodeArraySizeAMDX:
case spv::Decoration::PayloadNodeNameAMDX:
case spv::Decoration::PayloadNodeBaseIndexAMDX:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These could all do with a simple test that OpDecoration is disallowed, but OpDecorationId is allowed.

return true;
default:
break;
Expand Down
Loading
Loading