diff --git a/post-processor/vagrant/post-processor.go b/post-processor/vagrant/post-processor.go index 004365d27..ea2144263 100644 --- a/post-processor/vagrant/post-processor.go +++ b/post-processor/vagrant/post-processor.go @@ -56,44 +56,27 @@ type Config struct { } type PostProcessor struct { - configs map[string]*Config + config Config } func (p *PostProcessor) ConfigSpec() hcldec.ObjectSpec { - panic("not implemented yet") - // return p.config.FlatMapstructure().HCL2Spec() + return p.config.FlatMapstructure().HCL2Spec() } func (p *PostProcessor) Configure(raws ...interface{}) error { - p.configs = make(map[string]*Config) - p.configs[""] = new(Config) - if err := p.configureSingle(p.configs[""], raws...); err != nil { + if err := p.configureSingle(&p.config, raws...); err != nil { return err } - - // Go over any of the provider-specific overrides and load those up. - for name, override := range p.configs[""].Override { - subRaws := make([]interface{}, len(raws)+1) - copy(subRaws, raws) - subRaws[len(raws)] = override - - config := new(Config) - p.configs[name] = config - if err := p.configureSingle(config, subRaws...); err != nil { - return fmt.Errorf("Error configuring %s: %s", name, err) - } - } - return nil } func (p *PostProcessor) PostProcessProvider(name string, provider Provider, ui packer.Ui, artifact packer.Artifact) (packer.Artifact, bool, error) { - config := p.configs[""] - if specificConfig, ok := p.configs[name]; ok { - config = specificConfig + config, err := p.specificConfig(name) + if err != nil { + return nil, false, err } - err := CreateDummyBox(ui, config.CompressionLevel) + err = CreateDummyBox(ui, config.CompressionLevel) if err != nil { return nil, false, err } @@ -246,6 +229,17 @@ func (p *PostProcessor) configureSingle(c *Config, raws ...interface{}) error { return nil } +func (p *PostProcessor) specificConfig(name string) (Config, error) { + config := p.config + if _, ok := config.Override[name]; ok { + if err := mapstructure.Decode(config.Override[name], &config); err != nil { + err = fmt.Errorf("Error overriding config for %s: %s", name, err) + return config, err + } + } + return config, nil +} + func providerForName(name string) Provider { switch name { case "aws": diff --git a/post-processor/vagrant/post-processor_test.go b/post-processor/vagrant/post-processor_test.go index fccf4cf24..8d414bab2 100644 --- a/post-processor/vagrant/post-processor_test.go +++ b/post-processor/vagrant/post-processor_test.go @@ -46,7 +46,7 @@ func TestPostProcessorPrepare_compressionLevel(t *testing.T) { t.Fatalf("err: %s", err) } - config := p.configs[""] + config := p.config if config.CompressionLevel != flate.DefaultCompression { t.Fatalf("bad: %#v", config.CompressionLevel) } @@ -58,7 +58,7 @@ func TestPostProcessorPrepare_compressionLevel(t *testing.T) { t.Fatalf("err: %s", err) } - config = p.configs[""] + config = p.config if config.CompressionLevel != 7 { t.Fatalf("bad: %#v", config.CompressionLevel) } @@ -83,43 +83,48 @@ func TestPostProcessorPrepare_outputPath(t *testing.T) { } } -func TestPostProcessorPrepare_subConfigs(t *testing.T) { +func TestSpecificConfig(t *testing.T) { var p PostProcessor - f, err := ioutil.TempFile("", "packer") - if err != nil { - t.Fatalf("err: %s", err) - } - defer os.Remove(f.Name()) - // Default c := testConfig() - c["compression_level"] = 42 - c["vagrantfile_template"] = f.Name() + c["compression_level"] = 1 + c["output"] = "folder" c["override"] = map[string]interface{}{ "aws": map[string]interface{}{ "compression_level": 7, }, } - err = p.Configure(c) + if err := p.Configure(c); err != nil { + t.Fatalf("err: %s", err) + } + + // overrides config + config, err := p.specificConfig("aws") if err != nil { t.Fatalf("err: %s", err) } - if p.configs[""].CompressionLevel != 42 { - t.Fatalf("bad: %#v", p.configs[""].CompressionLevel) + if config.CompressionLevel != 7 { + t.Fatalf("bad: %#v", config.CompressionLevel) } - if p.configs[""].VagrantfileTemplate != f.Name() { - t.Fatalf("bad: %#v", p.configs[""].VagrantfileTemplate) + if config.OutputPath != "folder" { + t.Fatalf("bad: %#v", config.OutputPath) } - if p.configs["aws"].CompressionLevel != 7 { - t.Fatalf("bad: %#v", p.configs["aws"].CompressionLevel) + // does NOT overrides config + config, err = p.specificConfig("virtualbox") + if err != nil { + t.Fatalf("err: %s", err) + } + + if config.CompressionLevel != 1 { + t.Fatalf("bad: %#v", config.CompressionLevel) } - if p.configs["aws"].VagrantfileTemplate != f.Name() { - t.Fatalf("bad: %#v", p.configs["aws"].VagrantfileTemplate) + if config.OutputPath != "folder" { + t.Fatalf("bad: %#v", config.OutputPath) } }