Skip to content

Commit

Permalink
[xls] Ensure sitofp and friends work on tensors too
Browse files Browse the repository at this point in the history
Accidental omission from previous patch.

PiperOrigin-RevId: 682190931
  • Loading branch information
XLS Team authored and copybara-github committed Oct 4, 2024
1 parent a208c2f commit 87567ce
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
6 changes: 3 additions & 3 deletions xls/contrib/mlir/testdata/arith_to_xls.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,9 @@ func.func @ext(%arg0: i32) -> (i64, i64) attributes { "xls" = true } {
// CHECK-LABEL: @extf(
// CHECK: call_dslx
// CHECK-SAME: "ext"
func.func @extf(%arg0: bf16) -> f32 attributes { "xls" = true } {
%0 = arith.extf %arg0 : bf16 to f32
return %0 : f32
func.func @extf(%arg0: tensor<3x3xbf16>) -> tensor<3x3xf32> attributes { "xls" = true } {
%0 = arith.extf %arg0 : tensor<3x3xbf16> to tensor<3x3xf32>
return %0 : tensor<3x3xf32>
}

// CHECK-LABEL: @truncf(
Expand Down
13 changes: 8 additions & 5 deletions xls/contrib/mlir/transforms/arith_to_xls_patterns.td
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ def FloatLib : NativeCodeCall<"getFloatLib($0.getType())">;
// Shorthand for a constant string attribute.
class CS<string s> : ConstantStrAttr<StrAttr, s>;

class ScalarOrTensorOf<Type element> :
AnyTypeOf<[element, TensorOf<[element]>]>;

class BinaryOpPat<Arith_Op a, Xls_Op b> : Pat<(a $a, $b), (b $a, $b)>;
class BinaryOpOverflowPat<Arith_Op a, Xls_Op b> : Pat<(a $a, $b, $_), (b $a, $b)>;
class BinaryVariadicOpPat<Arith_Op a, Xls_Op b> : Pat<(a $a, $b), (b (variadic $a, $b))>;
Expand Down Expand Up @@ -131,27 +134,27 @@ def : Pat<(Arith_TruncFOp:$op $a, /*RoundingMode=*/$_, /*FastMathFlags=*/$_),

def : Pat<(Arith_SIToFPOp:$op I32:$a),
(Xls_CallDslxOp (FloatLib $op), CS<"from_int32">, (variadic $a), ConstUnitAttr),
[(F32 $op)]>;
[(ScalarOrTensorOf<F32> $op)]>;

def : Pat<(Arith_SIToFPOp:$op I8:$a),
(Xls_CallDslxOp (FloatLib $op), CS<"from_int8">, (variadic $a), ConstUnitAttr),
[(BF16 $op)]>;
[(ScalarOrTensorOf<BF16> $op)]>;

def : Pat<(Arith_FPToSIOp:$op F32:$a),
(Xls_CallDslxOp (FloatLib $a), CS<"to_int32">, (variadic $a), ConstUnitAttr),
[(I32 $op)]>;
[(ScalarOrTensorOf<I32> $op)]>;

def : Pat<(Arith_FPToSIOp:$op BF16:$a),
(Xls_CallDslxOp (FloatLib $a), CS<"to_int16">, (variadic $a), ConstUnitAttr),
[(I16 $op)]>;
[(ScalarOrTensorOf<I16> $op)]>;

// TODO(jmolloy): to_int8 doesn't exist, so truncating the result of to_int16
// seems like a reasonable approximation but I don't know if it's bit accurate.
def : Pat<(Arith_FPToSIOp:$op BF16:$a),
(Arith_TruncIOp
(Xls_CallDslxOp (FloatLib $a), CS<"to_int16">, (variadic $a),
ConstUnitAttr, (returnType "$_builder.getI16Type()"))),
[(I8 $op)]>;
[(ScalarOrTensorOf<I8> $op)]>;

// The expansion is a little tricky to read due to the one-hot select with the
// default case being the first argument.
Expand Down

0 comments on commit 87567ce

Please sign in to comment.