diff --git a/go.mod b/go.mod index 1a87e8e6e2..1ae761e761 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,6 @@ go 1.21 require ( github.com/antlr4-go/antlr/v4 v4.13.0 - github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 github.com/cubicdaiya/gonp v1.0.4 github.com/davecgh/go-spew v1.1.1 github.com/fatih/structtag v1.2.0 @@ -20,6 +19,7 @@ require ( github.com/riza-io/grpc-go v0.2.0 github.com/spf13/cobra v1.8.0 github.com/spf13/pflag v1.0.5 + github.com/tetratelabs/wazero v1.5.0 github.com/wasilibs/go-pgquery v0.0.0-20231208014744-de63626a1e99 github.com/xeipuuv/gojsonschema v1.2.0 golang.org/x/sync v0.5.0 diff --git a/go.sum b/go.sum index 9b6ea6d981..d7c014ce5d 100644 --- a/go.sum +++ b/go.sum @@ -3,8 +3,6 @@ github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0 github.com/antlr4-go/antlr/v4 v4.13.0 h1:lxCg3LAv+EUK6t1i0y1V6/SLeUi0eKEKdhQAlS8TVTI= github.com/antlr4-go/antlr/v4 v4.13.0/go.mod h1:pfChB/xh/Unjila75QW7+VU4TSnWnnk9UTnmpPaOR2g= github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA= -github.com/bytecodealliance/wasmtime-go/v14 v14.0.0 h1:ur7S3P+PAeJmgllhSrKnGQOAmmtUbLQxb/nw2NZiaEM= -github.com/bytecodealliance/wasmtime-go/v14 v14.0.0/go.mod h1:tqOVEUjnXY6aGpSfM9qdVRR6G//Yc513fFYUdzZb/DY= github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= @@ -183,6 +181,12 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tetratelabs/wazero v1.5.0 h1:Yz3fZHivfDiZFUXnWMPUoiW7s8tC1sjdBtlJn08qYa0= +github.com/tetratelabs/wazero v1.5.0/go.mod h1:0U0G41+ochRKoPKCJlh0jMg1CHkyfK8kDqiirMmKY8A= +github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e h1:sGIC6/D0KqpA+qBSDSVDQswU/IJVYkbnUXnipgTLQWk= +github.com/wasilibs/go-pgquery v0.0.0-20231205013331-96e794bb074e/go.mod h1:KW0azBSWqkPZ71r+3O4qt8h6A/NisFLp0rbjZ3py4OE= +github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6 h1:jwbU8u5TuXModzdEG4wI0g4FyuD7ROSttU86go5sPdU= +github.com/wasilibs/wazerox v0.0.0-20231117065139-b3503f4aeff6/go.mod h1:IQNVyA4d1hWIe23mlMMuqXjyWMdndgSlNx6FqBkwPsM= github.com/wasilibs/go-pgquery v0.0.0-20231208014744-de63626a1e99 h1:HFee1ByN4FrqNVd53Mo28ccGO+g5gxqUV/gdvKMe4b8= github.com/wasilibs/go-pgquery v0.0.0-20231208014744-de63626a1e99/go.mod h1:f2JMhFocVxY3VKMd9ykUxMnX4EVew9WOgjnfaNBB6C8= github.com/wasilibs/wazerox v0.0.0-20231208014050-e6b725634531 h1:zVJ4SZgaEE9sEH2L9k1+eAvCNa/WAAnT9UiMa3/tQrI= diff --git a/internal/endtoend/case_test.go b/internal/endtoend/case_test.go index 208b3fb9fa..50dcc57ec5 100644 --- a/internal/endtoend/case_test.go +++ b/internal/endtoend/case_test.go @@ -22,7 +22,6 @@ type Exec struct { Contexts []string `json:"contexts"` Process string `json:"process"` OS []string `json:"os"` - WASM bool `json:"wasm"` Env map[string]string `json:"env"` } diff --git a/internal/endtoend/endtoend_test.go b/internal/endtoend/endtoend_test.go index 2054baeee3..5753ce6d3a 100644 --- a/internal/endtoend/endtoend_test.go +++ b/internal/endtoend/endtoend_test.go @@ -16,7 +16,6 @@ import ( "github.com/sqlc-dev/sqlc/internal/cmd" "github.com/sqlc-dev/sqlc/internal/config" - "github.com/sqlc-dev/sqlc/internal/ext/wasm" "github.com/sqlc-dev/sqlc/internal/opts" ) @@ -177,10 +176,6 @@ func TestReplay(t *testing.T) { } } - if args.WASM && !wasm.Enabled() { - t.Skipf("wasm support not enabled") - } - if len(args.OS) > 0 { if !slices.Contains(args.OS, runtime.GOOS) { t.Skipf("unsupported os: %s", runtime.GOOS) diff --git a/internal/endtoend/testdata/wasm_plugin_sqlc_gen_greeter/exec.json b/internal/endtoend/testdata/wasm_plugin_sqlc_gen_greeter/exec.json deleted file mode 100644 index efe8bbc9aa..0000000000 --- a/internal/endtoend/testdata/wasm_plugin_sqlc_gen_greeter/exec.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "wasm": true -} diff --git a/internal/endtoend/testdata/wasm_plugin_sqlc_gen_test/exec.json b/internal/endtoend/testdata/wasm_plugin_sqlc_gen_test/exec.json deleted file mode 100644 index efe8bbc9aa..0000000000 --- a/internal/endtoend/testdata/wasm_plugin_sqlc_gen_test/exec.json +++ /dev/null @@ -1,3 +0,0 @@ -{ - "wasm": true -} diff --git a/internal/ext/wasm/nowasm.go b/internal/ext/wasm/nowasm.go deleted file mode 100644 index 14af0b54a2..0000000000 --- a/internal/ext/wasm/nowasm.go +++ /dev/null @@ -1,23 +0,0 @@ -//go:build nowasm || !(cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64))) - -package wasm - -import ( - "context" - - "google.golang.org/grpc" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" -) - -func Enabled() bool { - return false -} - -func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, opts ...grpc.CallOption) error { - return status.Error(codes.FailedPrecondition, "sqlc built without wasmtime support") -} - -func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { - return nil, status.Error(codes.Unimplemented, codes.Unimplemented.String()) -} diff --git a/internal/ext/wasm/wasm.go b/internal/ext/wasm/wasm.go index c096ec9844..a14c71d8a4 100644 --- a/internal/ext/wasm/wasm.go +++ b/internal/ext/wasm/wasm.go @@ -1,10 +1,7 @@ -//go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) - -// The above build constraint is based of the cgo directives in this file: -// https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go package wasm import ( + "bytes" "context" "crypto/sha256" "errors" @@ -15,10 +12,11 @@ import ( "os" "path/filepath" "runtime" - "runtime/trace" "strings" - wasmtime "github.com/bytecodealliance/wasmtime-go/v14" + "github.com/tetratelabs/wazero" + "github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" + "github.com/tetratelabs/wazero/sys" "golang.org/x/sync/singleflight" "google.golang.org/grpc" "google.golang.org/grpc/codes" @@ -31,31 +29,13 @@ import ( "github.com/sqlc-dev/sqlc/internal/plugin" ) -func Enabled() bool { - return true -} - -// This version must be updated whenever the wasmtime-go dependency is updated -const wasmtimeVersion = `v14.0.0` +var flight singleflight.Group -func cacheDir() (string, error) { - cache := os.Getenv("SQLCCACHE") - if cache != "" { - return cache, nil - } - cacheHome := os.Getenv("XDG_CACHE_HOME") - if cacheHome == "" { - home, err := os.UserHomeDir() - if err != nil { - return "", err - } - cacheHome = filepath.Join(home, ".cache") - } - return filepath.Join(cacheHome, "sqlc"), nil +type runtimeAndCode struct { + rt wazero.Runtime + code wazero.CompiledModule } -var flight singleflight.Group - // Verify the provided sha256 is valid. func (r *Runner) getChecksum(ctx context.Context) (string, error) { if r.SHA256 != "" { @@ -70,67 +50,26 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) { return sum, nil } -func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) { +func (r *Runner) loadAndCompile(ctx context.Context) (*runtimeAndCode, error) { expected, err := r.getChecksum(ctx) if err != nil { return nil, err } - value, err, _ := flight.Do(expected, func() (interface{}, error) { - return r.loadSerializedModule(ctx, engine, expected) - }) - if err != nil { - return nil, err - } - data, ok := value.([]byte) - if !ok { - return nil, fmt.Errorf("returned value was not a byte slice") - } - return wasmtime.NewModuleDeserialize(engine, data) -} - -func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) { cacheDir, err := cache.PluginsDir() if err != nil { return nil, err } - - pluginDir := filepath.Join(cacheDir, expectedSha) - modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion) - modPath := filepath.Join(pluginDir, modName) - _, staterr := os.Stat(modPath) - if staterr == nil { - data, err := os.ReadFile(modPath) - if err != nil { - return nil, err - } - return data, nil - } - - wmod, err := r.loadWASM(ctx, cacheDir, expectedSha) + value, err, _ := flight.Do(expected, func() (interface{}, error) { + return r.loadAndCompileWASM(ctx, cacheDir, expected) + }) if err != nil { return nil, err } - - moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule") - module, err := wasmtime.NewModule(engine, wmod) - moduRegion.End() - if err != nil { - return nil, fmt.Errorf("define wasi: %w", err) - } - - err = os.Mkdir(pluginDir, 0755) - if err != nil && !os.IsExist(err) { - return nil, fmt.Errorf("mkdirall: %w", err) - } - out, err := module.Serialize() - if err != nil { - return nil, fmt.Errorf("serialize: %w", err) - } - if err := os.WriteFile(modPath, out, 0444); err != nil { - return nil, fmt.Errorf("cache wasm: %w", err) + data, ok := value.(*runtimeAndCode) + if !ok { + return nil, fmt.Errorf("returned value was not a compiled module") } - - return out, nil + return data, nil } func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) { @@ -174,7 +113,7 @@ func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) return wmod, actual, nil } -func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([]byte, error) { +func (r *Runner) loadAndCompileWASM(ctx context.Context, cache string, expected string) (*runtimeAndCode, error) { pluginDir := filepath.Join(cache, expected) pluginPath := filepath.Join(pluginDir, "plugin.wasm") _, staterr := os.Stat(pluginPath) @@ -203,7 +142,26 @@ func (r *Runner) loadWASM(ctx context.Context, cache string, expected string) ([ } } - return wmod, nil + wazeroCache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cache, "wazero")) + if err != nil { + return nil, fmt.Errorf("wazero.NewCompilationCacheWithDir: %w", err) + } + + config := wazero.NewRuntimeConfig().WithCompilationCache(wazeroCache) + rt := wazero.NewRuntimeWithConfig(ctx, config) + + if _, err := wasi_snapshot_preview1.Instantiate(ctx, rt); err != nil { + return nil, fmt.Errorf("wasi_snapshot_preview1 instantiate: %w", err) + } + + // Compile the Wasm binary once so that we can skip the entire compilation + // time during instantiation. + code, err := rt.CompileModule(ctx, wmod) + if err != nil { + return nil, fmt.Errorf("compile module: %w", err) + } + + return &runtimeAndCode{rt: rt, code: code}, nil } // removePGCatalog removes the pg_catalog schema from the request. There is a @@ -245,75 +203,34 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, return fmt.Errorf("failed to encode codegen request: %w", err) } - engine := wasmtime.NewEngine() - module, err := r.loadModule(ctx, engine) - if err != nil { - return fmt.Errorf("loadModule: %w", err) - } - - linker := wasmtime.NewLinker(engine) - if err := linker.DefineWasi(); err != nil { - return err - } - - dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out") + runtimeAndCode, err := r.loadAndCompile(ctx) if err != nil { - return fmt.Errorf("temp dir: %w", err) - } - - defer os.RemoveAll(dir) - stdinPath := filepath.Join(dir, "stdin") - stderrPath := filepath.Join(dir, "stderr") - stdoutPath := filepath.Join(dir, "stdout") - - if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil { - return fmt.Errorf("write file: %w", err) + return fmt.Errorf("loadBytes: %w", err) } - // Configure WASI imports to write stdout into a file. - wasiConfig := wasmtime.NewWasiConfig() - wasiConfig.SetArgv([]string{"plugin.wasm", method}) - wasiConfig.SetStdinFile(stdinPath) - wasiConfig.SetStdoutFile(stdoutPath) - wasiConfig.SetStderrFile(stderrPath) + var stderr, stdout bytes.Buffer - keys := []string{"SQLC_VERSION"} - vals := []string{info.Version} + conf := wazero.NewModuleConfig(). + WithName(""). + WithArgs("plugin.wasm", method). + WithStdin(bytes.NewReader(stdinBlob)). + WithStdout(&stdout). + WithStderr(&stderr). + WithEnv("SQLC_VERSION", info.Version) for _, key := range r.Env { - keys = append(keys, key) - vals = append(vals, os.Getenv(key)) - } - wasiConfig.SetEnv(keys, vals) - - store := wasmtime.NewStore(engine) - store.SetWasi(wasiConfig) - - linkRegion := trace.StartRegion(ctx, "linker.DefineModule") - err = linker.DefineModule(store, "", module) - linkRegion.End() - if err != nil { - return fmt.Errorf("define wasi: %w", err) + conf = conf.WithEnv(key, os.Getenv(key)) } - // Run the function - fn, err := linker.GetDefault(store, "") - if err != nil { - return fmt.Errorf("wasi: get default: %w", err) + result, err := runtimeAndCode.rt.InstantiateModule(ctx, runtimeAndCode.code, conf) + if result != nil { + defer result.Close(ctx) } - - callRegion := trace.StartRegion(ctx, "call _start") - _, err = fn.Call(store) - callRegion.End() - - if cerr := checkError(err, stderrPath); cerr != nil { + if cerr := checkError(err, stderr); cerr != nil { return cerr } // Print WASM stdout - stdoutBlob, err := os.ReadFile(stdoutPath) - if err != nil { - return fmt.Errorf("read file: %w", err) - } + stdoutBlob := stdout.Bytes() resp, ok := reply.(protoreflect.ProtoMessage) if !ok { @@ -331,23 +248,21 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st return nil, status.Error(codes.Unimplemented, "") } -func checkError(err error, stderrPath string) error { +func checkError(err error, stderr bytes.Buffer) error { if err == nil { return err } - var wtError *wasmtime.Error - if errors.As(err, &wtError) { - if code, ok := wtError.ExitStatus(); ok { - if code == 0 { - return nil - } + if exitErr, ok := err.(*sys.ExitError); ok { + if exitErr.ExitCode() == 0 { + return nil } } + // Print WASM stdout - stderrBlob, rferr := os.ReadFile(stderrPath) - if rferr == nil && len(stderrBlob) > 0 { - return errors.New(string(stderrBlob)) + stderrBlob := stderr.String() + if len(stderrBlob) > 0 { + return errors.New(stderrBlob) } return fmt.Errorf("call: %w", err) }