-
Notifications
You must be signed in to change notification settings - Fork 843
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feat(plugins): Use wazero instead of wasmtime #3042
Changes from 2 commits
5e3d938
2878527
7517d03
89f4d8b
0914cad
245563d
e3cc250
18fbb55
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
This file was deleted.
This file was deleted.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,7 @@ | ||
//go:build !nowasm && cgo && ((linux && amd64) || (linux && arm64) || (darwin && amd64) || (darwin && arm64) || (windows && amd64)) | ||
|
||
// The above build constraint is based of the cgo directives in this file: | ||
// https://github.com/bytecodealliance/wasmtime-go/blob/main/ffi.go | ||
package wasm | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"crypto/sha256" | ||
"errors" | ||
|
@@ -15,10 +12,11 @@ import ( | |
"os" | ||
"path/filepath" | ||
"runtime" | ||
"runtime/trace" | ||
"strings" | ||
|
||
wasmtime "github.com/bytecodealliance/wasmtime-go/v14" | ||
"github.com/tetratelabs/wazero" | ||
"github.com/tetratelabs/wazero/imports/wasi_snapshot_preview1" | ||
"github.com/tetratelabs/wazero/sys" | ||
"golang.org/x/sync/singleflight" | ||
"google.golang.org/grpc" | ||
"google.golang.org/grpc/codes" | ||
|
@@ -31,13 +29,6 @@ import ( | |
"github.com/sqlc-dev/sqlc/internal/plugin" | ||
) | ||
|
||
func Enabled() bool { | ||
return true | ||
} | ||
|
||
// This version must be updated whenever the wasmtime-go dependency is updated | ||
const wasmtimeVersion = `v14.0.0` | ||
|
||
func cacheDir() (string, error) { | ||
cache := os.Getenv("SQLCCACHE") | ||
if cache != "" { | ||
|
@@ -70,13 +61,17 @@ func (r *Runner) getChecksum(ctx context.Context) (string, error) { | |
return sum, nil | ||
} | ||
|
||
func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasmtime.Module, error) { | ||
func (r *Runner) loadBytes(ctx context.Context) ([]byte, error) { | ||
expected, err := r.getChecksum(ctx) | ||
if err != nil { | ||
return nil, err | ||
} | ||
cacheDir, err := cache.PluginsDir() | ||
if err != nil { | ||
return nil, err | ||
} | ||
value, err, _ := flight.Do(expected, func() (interface{}, error) { | ||
return r.loadSerializedModule(ctx, engine, expected) | ||
return r.loadWASM(ctx, cacheDir, expected) | ||
}) | ||
if err != nil { | ||
return nil, err | ||
|
@@ -85,52 +80,7 @@ func (r *Runner) loadModule(ctx context.Context, engine *wasmtime.Engine) (*wasm | |
if !ok { | ||
return nil, fmt.Errorf("returned value was not a byte slice") | ||
} | ||
return wasmtime.NewModuleDeserialize(engine, data) | ||
} | ||
|
||
func (r *Runner) loadSerializedModule(ctx context.Context, engine *wasmtime.Engine, expectedSha string) ([]byte, error) { | ||
cacheDir, err := cache.PluginsDir() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
pluginDir := filepath.Join(cacheDir, expectedSha) | ||
modName := fmt.Sprintf("plugin_%s_%s_%s.module", runtime.GOOS, runtime.GOARCH, wasmtimeVersion) | ||
modPath := filepath.Join(pluginDir, modName) | ||
_, staterr := os.Stat(modPath) | ||
if staterr == nil { | ||
data, err := os.ReadFile(modPath) | ||
if err != nil { | ||
return nil, err | ||
} | ||
return data, nil | ||
} | ||
|
||
wmod, err := r.loadWASM(ctx, cacheDir, expectedSha) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
moduRegion := trace.StartRegion(ctx, "wasmtime.NewModule") | ||
module, err := wasmtime.NewModule(engine, wmod) | ||
moduRegion.End() | ||
if err != nil { | ||
return nil, fmt.Errorf("define wasi: %w", err) | ||
} | ||
|
||
err = os.Mkdir(pluginDir, 0755) | ||
if err != nil && !os.IsExist(err) { | ||
return nil, fmt.Errorf("mkdirall: %w", err) | ||
} | ||
out, err := module.Serialize() | ||
if err != nil { | ||
return nil, fmt.Errorf("serialize: %w", err) | ||
} | ||
if err := os.WriteFile(modPath, out, 0444); err != nil { | ||
return nil, fmt.Errorf("cache wasm: %w", err) | ||
} | ||
|
||
return out, nil | ||
return data, nil | ||
} | ||
|
||
func (r *Runner) fetch(ctx context.Context, uri string) ([]byte, string, error) { | ||
|
@@ -245,72 +195,56 @@ func (r *Runner) Invoke(ctx context.Context, method string, args any, reply any, | |
return fmt.Errorf("failed to encode codegen request: %w", err) | ||
} | ||
|
||
engine := wasmtime.NewEngine() | ||
module, err := r.loadModule(ctx, engine) | ||
cacheDir, err := cache.PluginsDir() | ||
if err != nil { | ||
return fmt.Errorf("loadModule: %w", err) | ||
return err | ||
} | ||
|
||
linker := wasmtime.NewLinker(engine) | ||
if err := linker.DefineWasi(); err != nil { | ||
cache, err := wazero.NewCompilationCacheWithDir(filepath.Join(cacheDir, "wazero")) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
dir, err := os.MkdirTemp(os.Getenv("SQLCTMPDIR"), "out") | ||
wasmBytes, err := r.loadBytes(ctx) | ||
if err != nil { | ||
return fmt.Errorf("temp dir: %w", err) | ||
return fmt.Errorf("loadModule: %w", err) | ||
} | ||
|
||
defer os.RemoveAll(dir) | ||
stdinPath := filepath.Join(dir, "stdin") | ||
stderrPath := filepath.Join(dir, "stderr") | ||
stdoutPath := filepath.Join(dir, "stdout") | ||
|
||
if err := os.WriteFile(stdinPath, stdinBlob, 0755); err != nil { | ||
return fmt.Errorf("write file: %w", err) | ||
} | ||
config := wazero.NewRuntimeConfig().WithCompilationCache(cache) | ||
rt := wazero.NewRuntimeWithConfig(ctx, config) | ||
defer rt.Close(ctx) | ||
|
||
// Configure WASI imports to write stdout into a file. | ||
wasiConfig := wasmtime.NewWasiConfig() | ||
wasiConfig.SetArgv([]string{"plugin.wasm", method}) | ||
wasiConfig.SetStdinFile(stdinPath) | ||
wasiConfig.SetStdoutFile(stdoutPath) | ||
wasiConfig.SetStderrFile(stderrPath) | ||
// TODO: Handle error | ||
wasi_snapshot_preview1.MustInstantiate(ctx, rt) | ||
|
||
keys := []string{"SQLC_VERSION"} | ||
vals := []string{info.Version} | ||
for _, key := range r.Env { | ||
keys = append(keys, key) | ||
vals = append(vals, os.Getenv(key)) | ||
// Compile the Wasm binary once so that we can skip the entire compilation time during instantiation. | ||
mod, err := rt.CompileModule(ctx, wasmBytes) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If it's possible, it would be nice to rejigger to scope this to |
||
if err != nil { | ||
return err | ||
} | ||
wasiConfig.SetEnv(keys, vals) | ||
|
||
store := wasmtime.NewStore(engine) | ||
store.SetWasi(wasiConfig) | ||
var stderr, stdout bytes.Buffer | ||
|
||
linkRegion := trace.StartRegion(ctx, "linker.DefineModule") | ||
err = linker.DefineModule(store, "", module) | ||
linkRegion.End() | ||
if err != nil { | ||
return fmt.Errorf("define wasi: %w", err) | ||
conf := wazero.NewModuleConfig() | ||
conf = conf.WithArgs("plugin.wasm", method) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit, consider chaining, it's arguably idiomatic for wazero users
|
||
conf = conf.WithEnv("SQLC_VERSION", info.Version) | ||
for _, key := range r.Env { | ||
conf = conf.WithEnv(key, os.Getenv(key)) | ||
} | ||
conf = conf.WithStdin(bytes.NewReader(stdinBlob)) | ||
conf = conf.WithStdout(&stdout) | ||
conf = conf.WithStderr(&stderr) | ||
|
||
// Run the function | ||
fn, err := linker.GetDefault(store, "") | ||
if err != nil { | ||
return fmt.Errorf("wasi: get default: %w", err) | ||
result, err := rt.InstantiateModule(ctx, mod, conf) | ||
if result != nil { | ||
defer result.Close(ctx) | ||
} | ||
|
||
callRegion := trace.StartRegion(ctx, "call _start") | ||
_, err = fn.Call(store) | ||
callRegion.End() | ||
|
||
if cerr := checkError(err, stderrPath); cerr != nil { | ||
if cerr := checkError(err, &stderr); cerr != nil { | ||
return cerr | ||
} | ||
|
||
// Print WASM stdout | ||
stdoutBlob, err := os.ReadFile(stdoutPath) | ||
stdoutBlob, err := io.ReadAll(&stdout) | ||
if err != nil { | ||
return fmt.Errorf("read file: %w", err) | ||
} | ||
kyleconroy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
@@ -331,21 +265,19 @@ func (r *Runner) NewStream(ctx context.Context, desc *grpc.StreamDesc, method st | |
return nil, status.Error(codes.Unimplemented, "") | ||
} | ||
|
||
func checkError(err error, stderrPath string) error { | ||
func checkError(err error, stderr io.Reader) error { | ||
if err == nil { | ||
return err | ||
} | ||
|
||
var wtError *wasmtime.Error | ||
if errors.As(err, &wtError) { | ||
if code, ok := wtError.ExitStatus(); ok { | ||
if code == 0 { | ||
return nil | ||
} | ||
if exitErr, ok := err.(*sys.ExitError); ok { | ||
if exitErr.ExitCode() == 0 { | ||
return nil | ||
} | ||
} | ||
|
||
// Print WASM stdout | ||
stderrBlob, rferr := os.ReadFile(stderrPath) | ||
stderrBlob, rferr := io.ReadAll(stderr) | ||
if rferr == nil && len(stderrBlob) > 0 { | ||
return errors.New(string(stderrBlob)) | ||
} | ||
kyleconroy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is
Instantiate
if you'd like to return the error. Though I think any failure here would be a programming bug, not non-determinstic