diff --git a/bundle/bundle.go b/bundle/bundle.go index a68c3c7125..816f5535fc 100644 --- a/bundle/bundle.go +++ b/bundle/bundle.go @@ -450,6 +450,7 @@ type Reader struct { name string persist bool regoVersion ast.RegoVersion + followSymlinks bool } // NewReader is deprecated. Use NewCustomReader instead. @@ -538,6 +539,11 @@ func (r *Reader) WithBundleName(name string) *Reader { return r } +func (r *Reader) WithFollowSymlinks(yes bool) *Reader { + r.followSymlinks = yes + return r +} + // WithLazyLoadingMode sets the bundle loading mode. If true, // bundles will be read in lazy mode. In this mode, data files in the bundle will not be // deserialized and the check to validate that the bundle data does not contain paths diff --git a/bundle/file.go b/bundle/file.go index c2c5a6b849..80b1a87eb1 100644 --- a/bundle/file.go +++ b/bundle/file.go @@ -6,6 +6,7 @@ import ( "compress/gzip" "fmt" "io" + "io/fs" "os" "path/filepath" "sort" @@ -126,6 +127,7 @@ type DirectoryLoader interface { WithFilter(filter filter.LoaderFilter) DirectoryLoader WithPathFormat(PathFormat) DirectoryLoader WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader + WithFollowSymlinks(followSymlinks bool) DirectoryLoader } type dirLoader struct { @@ -135,6 +137,7 @@ type dirLoader struct { filter filter.LoaderFilter pathFormat PathFormat maxSizeLimitBytes int64 + followSymlinks bool } // Normalize root directory, ex "./src/bundle" -> "src/bundle" @@ -181,6 +184,12 @@ func (d *dirLoader) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { return d } +// WithFollowSymlinks specifies whether to follow symlinks when loading files from the directory +func (d *dirLoader) WithFollowSymlinks(followSymlinks bool) DirectoryLoader { + d.followSymlinks = followSymlinks + return d +} + func formatPath(fileName string, root string, pathFormat PathFormat) string { switch pathFormat { case SlashRooted: @@ -212,7 +221,11 @@ func (d *dirLoader) NextFile() (*Descriptor, error) { if d.files == nil { d.files = 
[]string{} err := filepath.Walk(d.root, func(path string, info os.FileInfo, _ error) error { - if info != nil && info.Mode().IsRegular() { + if info == nil { + return nil + } + + if info.Mode().IsRegular() { if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { return nil } @@ -220,7 +233,15 @@ func (d *dirLoader) NextFile() (*Descriptor, error) { return fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(path, "/"), info.Size(), d.maxSizeLimitBytes) } d.files = append(d.files, path) - } else if info != nil && info.Mode().IsDir() { + } else if d.followSymlinks && info.Mode().Type()&fs.ModeSymlink == fs.ModeSymlink { + if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { + return nil + } + if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { + return fmt.Errorf(maxSizeLimitBytesErrMsg, strings.TrimPrefix(path, "/"), info.Size(), d.maxSizeLimitBytes) + } + d.files = append(d.files, path) + } else if info.Mode().IsDir() { if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, true)) { return filepath.SkipDir } @@ -305,6 +326,11 @@ func (t *tarballLoader) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader return t } +// WithFollowSymlinks is a no-op for tarballLoader +func (t *tarballLoader) WithFollowSymlinks(_ bool) DirectoryLoader { + return t +} + // NextFile iterates to the next file in the directory tree // and returns a file Descriptor for the file. 
func (t *tarballLoader) NextFile() (*Descriptor, error) { diff --git a/bundle/filefs.go b/bundle/filefs.go index e8767e1bae..a3a0dbf204 100644 --- a/bundle/filefs.go +++ b/bundle/filefs.go @@ -26,6 +26,7 @@ type dirLoaderFS struct { root string pathFormat PathFormat maxSizeLimitBytes int64 + followSymlinks bool } // NewFSLoader returns a basic DirectoryLoader implementation @@ -66,6 +67,16 @@ func (d *dirLoaderFS) walkDir(path string, dirEntry fs.DirEntry, err error) erro return fmt.Errorf("file %s size %d exceeds limit of %d", path, info.Size(), d.maxSizeLimitBytes) } + d.files = append(d.files, path) + } else if dirEntry.Type()&fs.ModeSymlink != 0 && d.followSymlinks { + if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, false)) { + return nil + } + + if d.maxSizeLimitBytes > 0 && info.Size() > d.maxSizeLimitBytes { + return fmt.Errorf("file %s size %d exceeds limit of %d", path, info.Size(), d.maxSizeLimitBytes) + } + d.files = append(d.files, path) } else if dirEntry.Type().IsDir() { if d.filter != nil && d.filter(filepath.ToSlash(path), info, getdepth(path, true)) { @@ -94,6 +105,11 @@ func (d *dirLoaderFS) WithSizeLimitBytes(sizeLimitBytes int64) DirectoryLoader { return d } +func (d *dirLoaderFS) WithFollowSymlinks(followSymlinks bool) DirectoryLoader { + d.followSymlinks = followSymlinks + return d +} + // NextFile iterates to the next file in the directory tree // and returns a file Descriptor for the file. 
func (d *dirLoaderFS) NextFile() (*Descriptor, error) { diff --git a/cmd/build.go b/cmd/build.go index fffb8f2a37..7c0b450684 100644 --- a/cmd/build.go +++ b/cmd/build.go @@ -45,6 +45,7 @@ type buildParams struct { plugin string ns string v1Compatible bool + followSymlinks bool } func newBuildParams() buildParams { @@ -238,6 +239,7 @@ against OPA v0.22.0: buildCommand.Flags().VarP(&buildParams.revision, "revision", "r", "set output bundle revision") buildCommand.Flags().StringVarP(&buildParams.outputFile, "output", "o", "bundle.tar.gz", "set the output filename") buildCommand.Flags().StringVar(&buildParams.ns, "partial-namespace", "partial", "set the namespace to use for partially evaluated files in an optimized bundle") + buildCommand.Flags().BoolVar(&buildParams.followSymlinks, "follow-symlinks", false, "follow symlinks in the input set of paths when building the bundle") addBundleModeFlag(buildCommand.Flags(), &buildParams.bundleMode, false) addIgnoreFlag(buildCommand.Flags(), &buildParams.ignore) @@ -302,7 +304,8 @@ func dobuild(params buildParams, args []string) error { WithFilter(buildCommandLoaderFilter(params.bundleMode, params.ignore)). WithBundleVerificationConfig(bvc). WithBundleSigningConfig(bsc). - WithPartialNamespace(params.ns) + WithPartialNamespace(params.ns). + WithFollowSymlinks(params.followSymlinks) if params.v1Compatible { compiler = compiler.WithRegoVersion(ast.RegoV1) diff --git a/cmd/build_test.go b/cmd/build_test.go index 8a47d66109..3dc1bcfbb6 100644 --- a/cmd/build_test.go +++ b/cmd/build_test.go @@ -2002,3 +2002,229 @@ foo contains __local1__1 if { }) } } + +// TestBuildWithFollowSymlinks tests that the build command follows symlinks when building a bundle. +// This test uses a local tmp filesystem to create a directory with a symlink to a file in its root +// and a local file in the bundle directory, and verifies that the built bundle contains both the symlinked file +// (stored as a regular file, not as a symlink) and the regular file. 
+// There's probably some common utilities that could be extracted at some point but for now this code is +// local to the test until we need to reuse it elsewhere. +func TestBuildWithFollowSymlinks(t *testing.T) { + rootDir, err := os.MkdirTemp("", "build-follow-symlinks") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := os.RemoveAll(rootDir); err != nil { + t.Fatal(err) + } + }() + bundleDir := path.Join(rootDir, "bundle") + err = os.Mkdir(bundleDir, 0777) + if err != nil { + t.Fatal(err) + } + + // create a regular file in our temp bundle directory + err = os.WriteFile(filepath.Join(bundleDir, "foo.rego"), []byte("package foo\none = 1"), 0777) + if err != nil { + t.Fatal(err) + } + + // create a regular file in the root directory of our tmp directory that we will symlink into the bundle directory later + err = os.WriteFile(filepath.Join(rootDir, "bar.rego"), []byte("package foo\ntwo = 2"), 0777) + if err != nil { + t.Fatal(err) + } + + // create a symlink in the bundle directory to the file in the root directory + err = os.Symlink(filepath.Join(rootDir, "bar.rego"), filepath.Join(bundleDir, "bar.rego")) + if err != nil { + t.Fatal(err) + } + + params := newBuildParams() + params.outputFile = path.Join(rootDir, "test.tar.gz") + params.bundleMode = true + params.followSymlinks = true + + err = dobuild(params, []string{bundleDir}) + if err != nil { + t.Fatal(err) + } + + // verify that the bundle is a loadable bundle + _, err = loader.NewFileLoader().AsBundle(params.outputFile) + if err != nil { + t.Fatal(err) + } + + f, err := os.Open(params.outputFile) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + gr, err := gzip.NewReader(f) + if err != nil { + t.Fatal(err) + } + + tr := tar.NewReader(gr) + + // map of file name -> file content + expectedFiles := map[string]string{ + bundleDir + "/foo.rego": "package foo\n\none = 1", + bundleDir + "/bar.rego": "package foo\n\ntwo = 2", + "/.manifest": `{"revision":"","roots":[""],"rego_version":0}`, 
+ "/data.json": "{}", + } + + foundFiles := make(map[string]string, 4) + for f, err := tr.Next(); err != io.EOF; f, err = tr.Next() { + if err != nil { + t.Fatal(err) + } + + // ensure that all the files are regular files in the bundle + // and that no symlinks were copied + if mode := f.FileInfo().Mode(); !mode.IsRegular() { + t.Fatalf("expected regular file for file %s but got %s", f.FileInfo().Name(), mode.String()) + } + // read the file content + data, err := io.ReadAll(tr) + if err != nil { + t.Fatalf("failed to read file %s: %v", f.FileInfo().Name(), err) + } + foundFiles[f.Name] = string(data) + } + + if len(foundFiles) != 4 { + t.Fatalf("expected four files in bundle but got %d", len(foundFiles)) + } + + for name, contents := range foundFiles { + // trim added whitespace because it's annoying and makes the test less readable + contents := strings.Trim(contents, "\n") + // check that the file content matches the expected content + expectedContent, ok := expectedFiles[name] + if !ok { + t.Fatalf("unexpected file %s in bundle", name) + } + + if contents != expectedContent { + t.Fatalf("expected file %s to contain:\n\n%v\n\ngot:\n\n%v", name, expectedContent, contents) + } + } +} + +// TestBuildWithFollowSymlinksEntireDir tests that the build command can build a bundle from a symlinked directory. +// This test uses a local tmp filesystem to create a directory with a local file in the bundle directory, and +// verifies that the built bundle contains the files from the symlinked directory. 
+func TestBuildWithFollowSymlinksEntireDir(t *testing.T) { + rootDir, err := os.MkdirTemp("", "build-follow-symlinks-dir") + if err != nil { + t.Fatal(err) + } + defer func() { + if err := os.RemoveAll(rootDir); err != nil { + t.Fatal(err) + } + }() + bundleDir := path.Join(rootDir, "src") + err = os.Mkdir(bundleDir, 0777) + if err != nil { + t.Fatal(err) + } + + // create a regular file in our temp bundle directory + err = os.WriteFile(filepath.Join(bundleDir, "foo.rego"), []byte("package foo\none = 1"), 0777) + if err != nil { + t.Fatal(err) + } + + symlinkDir := path.Join(rootDir, "symlink") + err = os.Mkdir(symlinkDir, 0777) + if err != nil { + t.Fatal(err) + } + + // create a symlink in the symlink directory to the src directory + err = os.Symlink(bundleDir, filepath.Join(symlinkDir, "linked")) + if err != nil { + t.Fatal(err) + } + + params := newBuildParams() + params.outputFile = path.Join(rootDir, "test.tar.gz") + params.bundleMode = true + params.followSymlinks = true + + err = dobuild(params, []string{symlinkDir + "/linked/"}) + if err != nil { + t.Fatal(err) + } + + // verify that the bundle is a loadable bundle + _, err = loader.NewFileLoader().AsBundle(params.outputFile) + if err != nil { + t.Fatal(err) + } + + f, err := os.Open(params.outputFile) + if err != nil { + t.Fatal(err) + } + defer f.Close() + + gr, err := gzip.NewReader(f) + if err != nil { + t.Fatal(err) + } + + tr := tar.NewReader(gr) + + // map of file name -> file content + expectedFiles := map[string]string{ + path.Join(symlinkDir, "linked", "foo.rego"): "package foo\n\none = 1", + "/.manifest": `{"revision":"","roots":[""],"rego_version":0}`, + "/data.json": "{}", + } + + foundFiles := make(map[string]string, 3) + for f, err := tr.Next(); err != io.EOF; f, err = tr.Next() { + if err != nil { + t.Fatal(err) + } + + // ensure that all the files are regular files in the bundle + // and that no symlinks were copied + if mode := f.FileInfo().Mode(); !mode.IsRegular() { + t.Fatalf("expected 
regular file for file %s but got %s", f.FileInfo().Name(), mode.String()) + } + // read the file content + data, err := io.ReadAll(tr) + if err != nil { + t.Fatalf("failed to read file %s: %v", f.FileInfo().Name(), err) + } + foundFiles[f.Name] = string(data) + } + + if len(foundFiles) != 3 { + t.Fatalf("expected three files in bundle but got %d", len(foundFiles)) + } + + for name, contents := range foundFiles { + // trim added whitespace because it's annoying and makes the test less readable + contents := strings.Trim(contents, "\n") + // check that the file content matches the expected content + expectedContent, ok := expectedFiles[name] + if !ok { + t.Fatalf("unexpected file %s in bundle", name) + } + + if contents != expectedContent { + t.Fatalf("expected file %s to contain:\n\n%v\n\ngot:\n\n%v", name, expectedContent, contents) + } + } +} diff --git a/compile/compile.go b/compile/compile.go index 9e80d01988..a0ac6fc4be 100644 --- a/compile/compile.go +++ b/compile/compile.go @@ -84,6 +84,7 @@ type Compiler struct { fsys fs.FS // file system to use when loading paths ns string regoVersion ast.RegoVersion + followSymlinks bool // optionally follow symlinks in the bundle directory when building the bundle } // New returns a new compiler instance that can be invoked. @@ -219,6 +220,12 @@ func (c *Compiler) WithCapabilities(capabilities *ast.Capabilities) *Compiler { return c } +// WithFollowSymlinks sets whether or not to follow symlinks in the bundle directory when building the bundle +func (c *Compiler) WithFollowSymlinks(yes bool) *Compiler { + c.followSymlinks = yes + return c +} + // WithMetadata sets the additional data to be included in .manifest func (c *Compiler) WithMetadata(metadata *map[string]interface{}) *Compiler { c.metadata = metadata @@ -471,7 +478,17 @@ func (c *Compiler) initBundle(usePath bool) error { // TODO(tsandall): the metrics object should passed through here so we that // we can track read and parse times. 
- load, err := initload.LoadPathsForRegoVersion(c.regoVersion, c.paths, c.filter, c.asBundle, c.bvc, false, c.useRegoAnnotationEntrypoints, c.capabilities, c.fsys) + load, err := initload.LoadPathsForRegoVersion( + c.regoVersion, + c.paths, + c.filter, + c.asBundle, + c.bvc, + false, + c.useRegoAnnotationEntrypoints, + c.followSymlinks, + c.capabilities, + c.fsys) if err != nil { return fmt.Errorf("load error: %w", err) } diff --git a/internal/pathwatcher/utils.go b/internal/pathwatcher/utils.go index 29a64c9079..31319a9ced 100644 --- a/internal/pathwatcher/utils.go +++ b/internal/pathwatcher/utils.go @@ -47,7 +47,7 @@ func ProcessWatcherUpdate(ctx context.Context, paths []string, removed string, s func ProcessWatcherUpdateForRegoVersion(ctx context.Context, regoVersion ast.RegoVersion, paths []string, removed string, store storage.Store, filter loader.Filter, asBundle bool, f func(context.Context, storage.Transaction, *initload.LoadPathsResult) error) error { - loaded, err := initload.LoadPathsForRegoVersion(regoVersion, paths, filter, asBundle, nil, true, false, nil, nil) + loaded, err := initload.LoadPathsForRegoVersion(regoVersion, paths, filter, asBundle, nil, true, false, false, nil, nil) if err != nil { return err } diff --git a/internal/runtime/init/init.go b/internal/runtime/init/init.go index 88e8bc4e0b..b1a5b71577 100644 --- a/internal/runtime/init/init.go +++ b/internal/runtime/init/init.go @@ -124,7 +124,7 @@ func LoadPaths(paths []string, processAnnotations bool, caps *ast.Capabilities, fsys fs.FS) (*LoadPathsResult, error) { - return LoadPathsForRegoVersion(ast.RegoV0, paths, filter, asBundle, bvc, skipVerify, processAnnotations, caps, fsys) + return LoadPathsForRegoVersion(ast.RegoV0, paths, filter, asBundle, bvc, skipVerify, processAnnotations, false, caps, fsys) } func LoadPathsForRegoVersion(regoVersion ast.RegoVersion, @@ -134,6 +134,7 @@ func LoadPathsForRegoVersion(regoVersion ast.RegoVersion, bvc *bundle.VerificationConfig, skipVerify bool, 
processAnnotations bool, + followSymlinks bool, caps *ast.Capabilities, fsys fs.FS) (*LoadPathsResult, error) { @@ -161,6 +162,7 @@ func LoadPathsForRegoVersion(regoVersion ast.RegoVersion, WithProcessAnnotation(processAnnotations). WithCapabilities(caps). WithRegoVersion(regoVersion). + WithFollowSymlinks(followSymlinks). AsBundle(path) if err != nil { return nil, err diff --git a/loader/loader.go b/loader/loader.go index e584bab3cb..759fc9b05d 100644 --- a/loader/loader.go +++ b/loader/loader.go @@ -103,6 +103,7 @@ type FileLoader interface { WithCapabilities(*ast.Capabilities) FileLoader WithJSONOptions(*astJSON.Options) FileLoader WithRegoVersion(ast.RegoVersion) FileLoader + WithFollowSymlinks(bool) FileLoader } // NewFileLoader returns a new FileLoader instance. @@ -114,14 +115,15 @@ func NewFileLoader() FileLoader { } type fileLoader struct { - metrics metrics.Metrics - filter Filter - bvc *bundle.VerificationConfig - skipVerify bool - files map[string]bundle.FileInfo - opts ast.ParserOptions - fsys fs.FS - reader io.Reader + metrics metrics.Metrics + filter Filter + bvc *bundle.VerificationConfig + skipVerify bool + files map[string]bundle.FileInfo + opts ast.ParserOptions + fsys fs.FS + reader io.Reader + followSymlinks bool } // WithFS provides an fs.FS to use for loading files. You can pass nil to @@ -188,6 +190,12 @@ func (fl *fileLoader) WithRegoVersion(version ast.RegoVersion) FileLoader { return fl } +// WithFollowSymlinks enables or disables following symlinks when loading files +func (fl *fileLoader) WithFollowSymlinks(followSymlinks bool) FileLoader { + fl.followSymlinks = followSymlinks + return fl +} + // All returns a Result object loaded (recursively) from the specified paths. 
func (fl fileLoader) All(paths []string) (*Result, error) { return fl.Filtered(paths, nil) @@ -249,6 +257,7 @@ func (fl fileLoader) AsBundle(path string) (*bundle.Bundle, error) { if err != nil { return nil, err } + bundleLoader = bundleLoader.WithFollowSymlinks(fl.followSymlinks) br := bundle.NewCustomReader(bundleLoader). WithMetrics(fl.metrics). @@ -257,6 +266,7 @@ func (fl fileLoader) AsBundle(path string) (*bundle.Bundle, error) { WithProcessAnnotations(fl.opts.ProcessAnnotation). WithCapabilities(fl.opts.Capabilities). WithJSONOptions(fl.opts.JSONOptions). + WithFollowSymlinks(fl.followSymlinks). WithRegoVersion(fl.opts.RegoVersion) // For bundle directories add the full path in front of module file names diff --git a/runtime/runtime.go b/runtime/runtime.go index 899ea74a9c..456c5cefb4 100644 --- a/runtime/runtime.go +++ b/runtime/runtime.go @@ -343,7 +343,7 @@ func NewRuntime(ctx context.Context, params Params) (*Runtime, error) { } else { regoVersion = ast.RegoV0 } - loaded, err := initload.LoadPathsForRegoVersion(regoVersion, params.Paths, params.Filter, params.BundleMode, params.BundleVerificationConfig, params.SkipBundleVerification, false, nil, nil) + loaded, err := initload.LoadPathsForRegoVersion(regoVersion, params.Paths, params.Filter, params.BundleMode, params.BundleVerificationConfig, params.SkipBundleVerification, false, false, nil, nil) if err != nil { return nil, fmt.Errorf("load error: %w", err) }