Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize AOT CALL_INDIRECT for calls in the same module #555

Merged
merged 4 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions aot/src/main/java/com/dylibso/chicory/aot/AotEmitters.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static com.dylibso.chicory.aot.AotUtil.StackSize;
import static com.dylibso.chicory.aot.AotUtil.callIndirectMethodName;
import static com.dylibso.chicory.aot.AotUtil.callIndirectMethodType;
import static com.dylibso.chicory.aot.AotUtil.emitInvokeFunction;
import static com.dylibso.chicory.aot.AotUtil.emitInvokeStatic;
import static com.dylibso.chicory.aot.AotUtil.emitInvokeVirtual;
import static com.dylibso.chicory.aot.AotUtil.emitJvmToLong;
Expand All @@ -42,8 +43,6 @@
import static com.dylibso.chicory.aot.AotUtil.jvmType;
import static com.dylibso.chicory.aot.AotUtil.loadTypeOpcode;
import static com.dylibso.chicory.aot.AotUtil.localType;
import static com.dylibso.chicory.aot.AotUtil.methodNameFor;
import static com.dylibso.chicory.aot.AotUtil.methodTypeFor;
import static com.dylibso.chicory.aot.AotUtil.stackSize;
import static com.dylibso.chicory.aot.AotUtil.storeTypeOpcode;
import static com.dylibso.chicory.aot.AotUtil.validateArgumentType;
Expand All @@ -54,7 +53,6 @@
import com.dylibso.chicory.wasm.types.FunctionType;
import com.dylibso.chicory.wasm.types.OpCode;
import com.dylibso.chicory.wasm.types.ValueType;
import java.lang.invoke.MethodType;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.Parameter;
Expand Down Expand Up @@ -128,18 +126,12 @@ public static void SELECT(AotContext ctx, AnnotatedInstruction ins, MethodVisito
public static void CALL(AotContext ctx, AnnotatedInstruction ins, MethodVisitor asm) {
int funcId = (int) ins.operand(0);
FunctionType functionType = ctx.functionTypes().get(funcId);
MethodType methodType = methodTypeFor(functionType);

emitInvokeStatic(asm, CHECK_INTERRUPTION);

asm.visitVarInsn(Opcodes.ALOAD, ctx.memorySlot());
asm.visitVarInsn(Opcodes.ALOAD, ctx.instanceSlot());
asm.visitMethodInsn(
Opcodes.INVOKESTATIC,
ctx.internalClassName(),
methodNameFor(funcId),
methodType.toMethodDescriptorString(),
false);
emitInvokeFunction(asm, ctx.internalClassName(), funcId, functionType);

if (functionType.returns().size() > 1) {
emitUnboxResult(asm, ctx, functionType.returns());
Expand All @@ -153,17 +145,16 @@ public static void CALL_INDIRECT(AotContext ctx, AnnotatedInstruction ins, Metho
int tableIdx = (int) ins.operand(1);
FunctionType functionType = ctx.types()[typeId];

MethodType methodType = callIndirectMethodType(functionType);

asm.visitLdcInsn(tableIdx);
asm.visitVarInsn(Opcodes.ALOAD, ctx.memorySlot());
asm.visitVarInsn(Opcodes.ALOAD, ctx.instanceSlot());
// stack: arguments, funcTableIdx, tableIdx, instance
// stack: arguments, funcTableIdx, tableIdx, memory, instance

asm.visitMethodInsn(
Opcodes.INVOKESTATIC,
ctx.internalClassName(),
callIndirectMethodName(typeId),
methodType.toMethodDescriptorString(),
callIndirectMethodType(functionType).toMethodDescriptorString(),
false);

if (functionType.returns().size() > 1) {
Expand Down
100 changes: 91 additions & 9 deletions aot/src/main/java/com/dylibso/chicory/aot/AotMachine.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@

import static com.dylibso.chicory.aot.AotMethods.CHECK_INTERRUPTION;
import static com.dylibso.chicory.aot.AotMethods.INSTANCE_CALL_HOST_FUNCTION;
import static com.dylibso.chicory.aot.AotMethods.INSTANCE_TABLE;
import static com.dylibso.chicory.aot.AotMethods.TABLE_INSTANCE;
import static com.dylibso.chicory.aot.AotMethods.TABLE_REF;
import static com.dylibso.chicory.aot.AotMethods.THROW_INDIRECT_CALL_TYPE_MISMATCH;
import static com.dylibso.chicory.aot.AotMethods.THROW_TRAP_EXCEPTION;
import static com.dylibso.chicory.aot.AotUtil.callIndirectMethodName;
import static com.dylibso.chicory.aot.AotUtil.callIndirectMethodType;
import static com.dylibso.chicory.aot.AotUtil.defaultValue;
import static com.dylibso.chicory.aot.AotUtil.emitInvokeFunction;
import static com.dylibso.chicory.aot.AotUtil.emitInvokeStatic;
import static com.dylibso.chicory.aot.AotUtil.emitInvokeVirtual;
import static com.dylibso.chicory.aot.AotUtil.emitJvmToLong;
Expand Down Expand Up @@ -57,6 +62,7 @@
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodType;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
Expand Down Expand Up @@ -444,7 +450,7 @@ private byte[] compileClass(String className, FunctionSection functions) {
classWriter,
callIndirectMethodName(typeId),
callIndirectMethodType(type),
asm -> compileCallIndirect(asm, typeId, type));
asm -> compileCallIndirect(internalClassName, typeId, type, asm));
}

var returnTypes =
Expand Down Expand Up @@ -514,8 +520,6 @@ private Class<?> loadClass(String className, byte[] classBytes) {
try {
ClassReader reader = new ClassReader(classBytes);
CheckClassAdapter.verify(reader, true, new PrintWriter(System.out, false, UTF_8));
} catch (NoClassDefFoundError ignored) {
// the ASM verifier is an optional dependency
} catch (Throwable t) {
e.addSuppressed(t);
}
Expand Down Expand Up @@ -562,15 +566,93 @@ private static void emitConstructor(ClassVisitor writer) {
cons.visitEnd();
}

private static void compileCallIndirect(MethodVisitor asm, int typeId, FunctionType type) {
int slot = type.params().stream().mapToInt(AotUtil::slotCount).sum();
private void compileCallIndirect(
String internalClassName, int typeId, FunctionType type, MethodVisitor asm) {
int slots = type.params().stream().mapToInt(AotUtil::slotCount).sum();
int funcTableIdx = slots;
int tableIdx = slots + 1;
int memory = slots + 2;
int instance = slots + 3;
int table = slots + 4;
int funcId = slots + 5;
int refInstance = slots + 6;

emitInvokeStatic(asm, CHECK_INTERRUPTION);

// TableInstance table = instance.table(tableIdx);
asm.visitVarInsn(Opcodes.ALOAD, instance);
asm.visitVarInsn(Opcodes.ILOAD, tableIdx);
emitInvokeVirtual(asm, INSTANCE_TABLE);
asm.visitVarInsn(Opcodes.ASTORE, table);

// int funcId = tableRef(table, funcTableIdx);
asm.visitVarInsn(Opcodes.ALOAD, table);
asm.visitVarInsn(Opcodes.ILOAD, funcTableIdx);
emitInvokeStatic(asm, TABLE_REF);
asm.visitVarInsn(Opcodes.ISTORE, funcId);

// Instance refInstance = table.instance(funcTableIdx);
asm.visitVarInsn(Opcodes.ALOAD, table);
asm.visitVarInsn(Opcodes.ILOAD, funcTableIdx);
emitInvokeVirtual(asm, TABLE_INSTANCE);
asm.visitVarInsn(Opcodes.ASTORE, refInstance);

Label local = new Label();
Label other = new Label();

// if (refInstance == null || refInstance == instance)
asm.visitVarInsn(Opcodes.ALOAD, refInstance);
asm.visitJumpInsn(Opcodes.IFNULL, local);
asm.visitVarInsn(Opcodes.ALOAD, refInstance);
asm.visitVarInsn(Opcodes.ALOAD, instance);
asm.visitJumpInsn(Opcodes.IF_ACMPNE, other);

// local: call function in this module
asm.visitLabel(local);

int slot = 0;
for (ValueType param : type.params()) {
asm.visitVarInsn(loadTypeOpcode(param), slot);
slot += slotCount(param);
}
asm.visitVarInsn(Opcodes.ALOAD, memory);
asm.visitVarInsn(Opcodes.ALOAD, instance);

List<Integer> validIds = new ArrayList<>();
for (int i = 0; i < functionTypes.size(); i++) {
if (type.equals(functionTypes.get(i))) {
validIds.add(i);
}
}

Label invalid = new Label();
int[] keys = validIds.stream().mapToInt(x -> x).toArray();
Label[] labels = validIds.stream().map(x -> new Label()).toArray(Label[]::new);

asm.visitVarInsn(Opcodes.ILOAD, funcId);
asm.visitLookupSwitchInsn(invalid, keys, labels);

Label done = new Label();
for (int i = 0; i < validIds.size(); i++) {
asm.visitLabel(labels[i]);
emitInvokeFunction(asm, internalClassName, keys[i], type);
asm.visitJumpInsn(Opcodes.GOTO, done);
}

asm.visitLabel(invalid);
emitInvokeStatic(asm, THROW_INDIRECT_CALL_TYPE_MISMATCH);
asm.visitInsn(Opcodes.ATHROW);

asm.visitLabel(done);
asm.visitInsn(returnTypeOpcode(type));

// other: call function in another module
asm.visitLabel(other);

// parameters: arguments, funcTableIdx, tableIdx, instance
emitBoxArguments(asm, type.params());
asm.visitLdcInsn(typeId);
asm.visitVarInsn(Opcodes.ILOAD, slot); // funcTableIdx
asm.visitVarInsn(Opcodes.ILOAD, slot + 1); // tableIdx
asm.visitVarInsn(Opcodes.ALOAD, slot + 2); // instance
asm.visitVarInsn(Opcodes.ILOAD, funcId);
asm.visitVarInsn(Opcodes.ALOAD, refInstance);

emitInvokeStatic(asm, AotMethods.CALL_INDIRECT);

Expand Down
47 changes: 26 additions & 21 deletions aot/src/main/java/com/dylibso/chicory/aot/AotMethods.java
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package com.dylibso.chicory.aot;

import static com.dylibso.chicory.wasm.types.Value.REF_NULL_VALUE;
import static java.util.Objects.requireNonNullElse;

import com.dylibso.chicory.runtime.Instance;
import com.dylibso.chicory.runtime.Memory;
Expand All @@ -23,6 +22,7 @@ public final class AotMethods {
static final Method INSTANCE_READ_GLOBAL;
static final Method WRITE_GLOBAL;
static final Method INSTANCE_SET_ELEMENT;
static final Method INSTANCE_TABLE;
static final Method MEMORY_COPY;
static final Method MEMORY_FILL;
static final Method MEMORY_INIT;
Expand All @@ -49,7 +49,10 @@ public final class AotMethods {
static final Method TABLE_FILL;
static final Method TABLE_COPY;
static final Method TABLE_INIT;
static final Method TABLE_REF;
static final Method TABLE_INSTANCE;
static final Method VALIDATE_BASE;
static final Method THROW_INDIRECT_CALL_TYPE_MISMATCH;
static final Method THROW_OUT_OF_BOUNDS_MEMORY_ACCESS;
static final Method THROW_TRAP_EXCEPTION;

Expand All @@ -58,19 +61,15 @@ public final class AotMethods {
CHECK_INTERRUPTION = AotMethods.class.getMethod("checkInterruption");
CALL_INDIRECT =
AotMethods.class.getMethod(
"callIndirect",
long[].class,
int.class,
int.class,
int.class,
Instance.class);
"callIndirect", long[].class, int.class, int.class, Instance.class);
INSTANCE_CALL_HOST_FUNCTION =
Instance.class.getMethod("callHostFunction", int.class, long[].class);
INSTANCE_READ_GLOBAL = Instance.class.getMethod("readGlobal", int.class);
WRITE_GLOBAL =
AotMethods.class.getMethod(
"writeGlobal", long.class, int.class, Instance.class);
INSTANCE_SET_ELEMENT = Instance.class.getMethod("setElement", int.class, Element.class);
INSTANCE_TABLE = Instance.class.getMethod("table", int.class);
MEMORY_COPY =
AotMethods.class.getMethod(
"memoryCopy", int.class, int.class, int.class, Memory.class);
Expand Down Expand Up @@ -154,7 +153,11 @@ public final class AotMethods {
int.class,
int.class,
Instance.class);
TABLE_REF = AotMethods.class.getMethod("tableRef", TableInstance.class, int.class);
TABLE_INSTANCE = TableInstance.class.getMethod("instance", int.class);
VALIDATE_BASE = AotMethods.class.getMethod("validateBase", int.class);
THROW_INDIRECT_CALL_TYPE_MISMATCH =
AotMethods.class.getMethod("throwIndirectCallTypeMismatch");
THROW_OUT_OF_BOUNDS_MEMORY_ACCESS =
AotMethods.class.getMethod("throwOutOfBoundsMemoryAccess");
THROW_TRAP_EXCEPTION = AotMethods.class.getMethod("throwTrapException");
Expand All @@ -166,24 +169,12 @@ public final class AotMethods {
private AotMethods() {}

@UsedByGeneratedCode
public static long[] callIndirect(
long[] args, int typeId, int funcTableIdx, int tableIdx, Instance instance) {
TableInstance table = instance.table(tableIdx);

instance = requireNonNullElse(table.instance(funcTableIdx), instance);

int funcId = table.ref(funcTableIdx);
if (funcId == REF_NULL_VALUE) {
throw new ChicoryException("uninitialized element " + funcTableIdx);
}

public static long[] callIndirect(long[] args, int typeId, int funcId, Instance instance) {
FunctionType expectedType = instance.type(typeId);
FunctionType actualType = instance.type(instance.functionType(funcId));
if (!actualType.typesMatch(expectedType)) {
throw new ChicoryException("indirect call type mismatch");
throw throwIndirectCallTypeMismatch();
}

checkInterruption();
return instance.getMachine().call(funcId, args);
}

Expand Down Expand Up @@ -230,6 +221,15 @@ public static void tableInit(
OpcodeImpl.TABLE_INIT(instance, tableidx, elementidx, size, elemidx, offset);
}

@UsedByGeneratedCode
public static int tableRef(TableInstance table, int index) {
int funcId = table.ref(index);
if (funcId == REF_NULL_VALUE) {
throw new ChicoryException("uninitialized element " + index);
}
return funcId;
}

@UsedByGeneratedCode
public static void memoryCopy(int destination, int offset, int size, Memory memory) {
memory.copy(destination, offset, size);
Expand Down Expand Up @@ -326,6 +326,11 @@ public static void validateBase(int base) {
}
}

@UsedByGeneratedCode
public static RuntimeException throwIndirectCallTypeMismatch() {
return new ChicoryException("indirect call type mismatch");
}

@UsedByGeneratedCode
public static RuntimeException throwOutOfBoundsMemoryAccess() {
throw new WASMRuntimeException("out of bounds memory access");
Expand Down
12 changes: 11 additions & 1 deletion aot/src/main/java/com/dylibso/chicory/aot/AotUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ public static MethodHandle jvmToLongHandle(ValueType type) {

public static MethodType callIndirectMethodType(FunctionType functionType) {
return rawMethodTypeFor(functionType)
.appendParameterTypes(int.class, int.class, Instance.class);
.appendParameterTypes(int.class, int.class, Memory.class, Instance.class);
}

public static MethodType methodTypeFor(FunctionType type) {
Expand Down Expand Up @@ -319,6 +319,16 @@ public static void emitInvokeVirtual(MethodVisitor asm, Method method) {
false);
}

public static void emitInvokeFunction(
MethodVisitor asm, String internalClassName, int funcId, FunctionType functionType) {
asm.visitMethodInsn(
Opcodes.INVOKESTATIC,
internalClassName,
methodNameFor(funcId),
methodTypeFor(functionType).toMethodDescriptorString(),
false);
}

public static String methodNameFor(int funcId) {
return "func_" + funcId;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,10 @@ private static void verifyGeneratedBytecode(String name, HostFunction... hostFun
var out = new ByteArrayOutputStream();
cr.accept(new TraceClassVisitor(new PrintWriter(out, false, UTF_8)), 0);

Approvals.verify(out);
String output = out.toString(UTF_8);
output = output.replaceAll("(?m)^ {3}FRAME.*\\n", "");
output = output.replaceAll("(?m)^ {4}MAX(STACK|LOCALS) = \\d+\\n", "");

Approvals.verify(output);
}
}
Loading
Loading