diff --git a/examples/set/gotemplate_mySet.go b/examples/set/gotemplate_mySet.go index 98420ab..4b5a812 100644 --- a/examples/set/gotemplate_mySet.go +++ b/examples/set/gotemplate_mySet.go @@ -1,3 +1,5 @@ +// Code generated by gotemplate. DO NOT EDIT. + // Template Set type // // Tries to be similar to Python's set type diff --git a/examples/sort/gotemplate_Sort.go b/examples/sort/gen_Sort_gotemplate.go similarity index 99% rename from examples/sort/gotemplate_Sort.go rename to examples/sort/gen_Sort_gotemplate.go index 7eb92cf..667987d 100644 --- a/examples/sort/gotemplate_Sort.go +++ b/examples/sort/gen_Sort_gotemplate.go @@ -1,3 +1,5 @@ +// Code generated by gotemplate. DO NOT EDIT. + // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. diff --git a/examples/sort/gotemplate_SortF.go b/examples/sort/gotemplate_SortF.go index e821fbb..951539b 100644 --- a/examples/sort/gotemplate_SortF.go +++ b/examples/sort/gotemplate_SortF.go @@ -1,3 +1,5 @@ +// Code generated by gotemplate. DO NOT EDIT. + // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. diff --git a/examples/sort/gotemplate_SortGt.go b/examples/sort/gotemplate_SortGt.go index 7d3ba36..5a18c29 100644 --- a/examples/sort/gotemplate_SortGt.go +++ b/examples/sort/gotemplate_SortGt.go @@ -1,3 +1,5 @@ +// Code generated by gotemplate. DO NOT EDIT. + // Copyright 2009 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. diff --git a/examples/sort/main.go b/examples/sort/main.go index d168554..5defb71 100644 --- a/examples/sort/main.go +++ b/examples/sort/main.go @@ -8,7 +8,7 @@ import "fmt" // Regenerate the templates with "go generate" // Sort strings using the less function -//go:generate gotemplate "github.com/ncw/gotemplate/sort" "Sort(string, less)" +//go:generate gotemplate -outfmt gen_%v_gotemplate "github.com/ncw/gotemplate/sort" "Sort(string, less)" // Sort floats using the lt function //go:generate gotemplate "github.com/ncw/gotemplate/sort" "SortF(float64, lt)" diff --git a/main.go b/main.go index 98e5429..5a4f124 100644 --- a/main.go +++ b/main.go @@ -36,6 +36,7 @@ import ( var ( // Flags verbose = flag.Bool("v", false, "Verbose - print lots of stuff") + outfile = flag.String("outfmt", "gotemplate_%v", "the format of the output file") ) // Logging function @@ -67,8 +68,46 @@ func usage() { } func main() { + log.SetFlags(0) + log.SetPrefix("") + flag.Usage = usage flag.Parse() + + // do some basic validation on the outfile format + // indexing is safe because we're searching for an ascii char + { + found := 0 + c := *outfile + + for len(c) != 0 { + l := c[0] + c = c[1:] + + if l != '%' { + continue + } + + // end of string + if len(c) == 0 { + fatalf("invalid outfile format ending in %v", "%") + } + + switch c[0] { + case 'v': + found++ + default: + fatalf("outfile format contains invalid verb: %v", "%"+string(c[0])) + } + + c = c[1:] + } + + if found != 1 { + fatalf("could not find %v in outfile format", "%v") + } + } + args := flag.Args() if len(args) != 2 { fatalf("Need 2 arguments, package and parameters") diff --git a/template.go b/template.go index 92ff5c8..e5ea51f 100644 --- a/template.go +++ b/template.go @@ -4,17 +4,22 @@ package main import ( "bytes" + "fmt" "go/ast" "go/build" "go/format" "go/parser" "go/token" - "os" + "io/ioutil" "path" "regexp" "strings" ) +const ( + genHeader = "// Code generated by gotemplate. DO NOT EDIT.\n\n" +) + // Holds the desired template type template struct { Package string @@ -108,7 +113,7 @@ func parseTemplateAndArgs(s string) (name string, args []string) { // "template type Set(A)" var matchTemplateType = regexp.MustCompile(`^//\s*template\s+type\s+(\w+\s*.*?)\s*$`) -func (t *template) findTemplateDefinition(f *ast.File) { +func (t *template) findTemplateDefinition(f *ast.File) bool { // Inspect the comments t.templateName = "" t.templateArgs = nil @@ -124,37 +129,22 @@ func (t *template) findTemplateDefinition(f *ast.File) { } } if t.templateName == "" { - fatalf("Didn't find template definition in %s", t.inputFile) + return false } if len(t.templateArgs) != len(t.Args) { fatalf("Wrong number of arguments - template is expecting %d but %d supplied", len(t.Args), len(t.templateArgs)) } debugf("templateName = %v, templateArgs = %v", t.templateName, t.templateArgs) -} -// Ouput the go formatted file -// -// Exits with a fatal error on error -func outputFile(fset *token.FileSet, f *ast.File, path string) { - fd, err := os.Create(path) - if err != nil { - fatalf("Failed to open %q: %s", path, err) - } - if err := format.Node(fd, fset, f); err != nil { - fatalf("Failed to format %q: %s", path, err) - } - err = fd.Close() - if err != nil { - fatalf("Failed to close %q: %s", path, err) - } + return true } // Parses a file into a Fileset and Ast // // Dies with a fatal error on error -func parseFile(path string) (*token.FileSet, *ast.File) { +func parseFile(path string, src interface{}) (*token.FileSet, *ast.File) { fset := token.NewFileSet() // positions are relative to fset - f, err := parser.ParseFile(fset, path, nil, parser.ParseComments) + f, err := parser.ParseFile(fset, path, src, parser.ParseComments) if err != nil { fatalf("Failed to parse file: %s", err) } @@ -189,14 +179,9 @@ func (t *template) isTemplateArgument(name string) bool { } // Parses the template file -func (t *template) parse(inputFile string) { - t.inputFile = inputFile - // Make the name mappings +func (t *template) parse(fset *token.FileSet, f *ast.File) { t.newIsPublic = ast.IsExported(t.Name) - fset, f := parseFile(inputFile) - t.findTemplateDefinition(f) - // debugf("Decls = %#v", f.Decls) // Find names which need to be adjusted namesToMangle := []string{} @@ -311,19 +296,48 @@ func (t *template) parse(inputFile string) { // Change the package to the local package name f.Name.Name = t.NewPackage - // Output - outputFileName := "gotemplate_" + t.Name + ".go" - outputFile(fset, f, outputFileName) + // Output but only if contents have changed from existing file + + b := bytes.NewBuffer(nil) + outputFileName := fmt.Sprintf(*outfile+".go", t.Name) + + format := func() { + b.Reset() + if err := format.Node(b, fset, f); err != nil { + fatalf("Failed to format output: %s", err) + } + } + + format() - // gofmt one last time to sort out messy identifier substution - fset, f = parseFile(outputFileName) - outputFile(fset, f, outputFileName) - logf("Written '%s'", outputFileName) + // bit gross to inject the header this way... but in the spirit of + // minimal changes et al... + fset, f = parseFile(outputFileName, genHeader+b.String()) + + format() + + write := true + + curr, err := ioutil.ReadFile(outputFileName) + if err == nil { + if bytes.Equal(curr, b.Bytes()) { + write = false + } + } + + if write { + err := ioutil.WriteFile(outputFileName, b.Bytes(), 0666) + if err != nil { + fatalf("unable to write to %q: %v", outputFileName, err) + } + } + + debugf("Written '%s'", outputFileName) } // Instantiate the template package func (t *template) instantiate() { - logf("Substituting %q with %s(%s) into package %s", t.Package, t.Name, strings.Join(t.Args, ","), t.NewPackage) + debugf("Substituting %q with %s(%s) into package %s", t.Package, t.Name, strings.Join(t.Args, ","), t.NewPackage) p, err := build.Default.Import(t.Package, t.Dir, build.ImportMode(0)) if err != nil { @@ -334,14 +348,27 @@ func (t *template) instantiate() { // FIXME CgoFiles ? debugf("Go files = %#v", p.GoFiles) - if len(p.GoFiles) == 0 { - fatalf("No go files found for package '%s'", t.Package) + count := 0 + + var fset *token.FileSet + var file *ast.File + + for _, f := range p.GoFiles { + templateFilePath := path.Join(p.Dir, f) + fset, file = parseFile(templateFilePath, nil) + if v := t.findTemplateDefinition(file); v { + t.inputFile = templateFilePath + count++ + } } - // FIXME - if len(p.GoFiles) != 1 { - fatalf("Found more than one go file in '%s' - can only cope with 1 for the moment, sorry", t.Package) + + if count == 0 { + fatalf("Failed to find a .go file with a template definition in %v", t.Package) + } + + if count > 1 { + fatalf("Found more than template definition %s - can only cope with 1 for the moment, sorry", t.Package) } - templateFilePath := path.Join(p.Dir, p.GoFiles[0]) - t.parse(templateFilePath) + t.parse(fset, file) } diff --git a/template_test.go b/template_test.go index 5c54887..064afad 100644 --- a/template_test.go +++ b/template_test.go @@ -51,7 +51,9 @@ var tests = []TestTemplate{ pkg: "main", in: basicTest, outName: "gotemplate_MySet.go", - out: `package main + out: `// Code generated by gotemplate. DO NOT EDIT. + +package main import "fmt" @@ -80,7 +82,9 @@ var ( pkg: "main", in: basicTest, outName: "gotemplate_mySet.go", - out: `package main + out: `// Code generated by gotemplate. DO NOT EDIT. + +package main import "fmt" @@ -117,7 +121,9 @@ func TT(a, b A) A { return Less(a, b) } func TTone(a A) A { return !Less(a, b) } `, outName: "gotemplate_Min.go", - out: `package main + out: `// Code generated by gotemplate. DO NOT EDIT. + +package main // template type TT(A, Less) @@ -152,7 +158,9 @@ func (v Vector) Add(b Vector) { } `, outName: "gotemplate_Vector2.go", - out: `package main + out: `// Code generated by gotemplate. DO NOT EDIT. + +package main // template type Vector(A, n) @@ -189,7 +197,9 @@ func (mat Matrix) Add(x Matrix) { } `, outName: "gotemplate_Matrix22.go", - out: `package main + out: `// Code generated by gotemplate. DO NOT EDIT. + +package main // template type Matrix(A, n, m) @@ -244,7 +254,9 @@ var ( func Prog() int {return a+b+c+d+e+f} `, outName: "gotemplate_ProgXX.go", - out: `package main + out: `// Code generated by gotemplate. DO NOT EDIT. + +package main // template type Prog(a, b, c, d, e, f) type AProgXX float32 @@ -308,7 +320,9 @@ type ( ) `, outName: "gotemplate_tmpl.go", - out: `package main + out: `// Code generated by gotemplate. DO NOT EDIT. + +package main // template type TMPL(A, B, C, D, E, F)