Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions examples/set/gotemplate_mySet.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions examples/sort/gotemplate_SortF.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions examples/sort/gotemplate_SortGt.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion examples/sort/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import "fmt"
// Regenerate the templates with "go generate"

// Sort strings using the less function
//go:generate gotemplate "github.com/ncw/gotemplate/sort" "Sort(string, less)"
//go:generate gotemplate -outfmt gen_%v_gotemplate "github.com/ncw/gotemplate/sort" "Sort(string, less)"

// Sort floats using the lt function
//go:generate gotemplate "github.com/ncw/gotemplate/sort" "SortF(float64, lt)"
Expand Down
39 changes: 39 additions & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ import (
var (
// Flags
verbose = flag.Bool("v", false, "Verbose - print lots of stuff")
outfile = flag.String("outfmt", "gotemplate_%v", "the format of the output file")
)

// Logging function
Expand Down Expand Up @@ -67,8 +68,46 @@ func usage() {
}

func main() {
log.SetFlags(0)
log.SetPrefix("")

flag.Usage = usage
flag.Parse()

// do some basic validation on the outfile format
// indexing is safe because we're searching for an ascii char
{
found := 0
c := *outfile

for len(c) != 0 {
l := c[0]
c = c[1:]

if l != '%' {
continue
}

// end of string
if len(c) == 0 {
fatalf("invalid outfile format ending in %v", "%")
}

switch c[0] {
case 'v':
found++
default:
fatalf("outfile format contains invalid verb: %v", "%"+string(c[0]))
}

c = c[1:]
}

if found != 1 {
fatalf("could not find %v in outfile format", "%v")
}
}

args := flag.Args()
if len(args) != 2 {
fatalf("Need 2 arguments, package and parameters")
Expand Down
111 changes: 69 additions & 42 deletions template.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@ package main

import (
"bytes"
"fmt"
"go/ast"
"go/build"
"go/format"
"go/parser"
"go/token"
"os"
"io/ioutil"
"path"
"regexp"
"strings"
)

const (
genHeader = "// Code generated by gotemplate. DO NOT EDIT.\n\n"
)

// Holds the desired template
type template struct {
Package string
Expand Down Expand Up @@ -108,7 +113,7 @@ func parseTemplateAndArgs(s string) (name string, args []string) {
// "template type Set(A)"
var matchTemplateType = regexp.MustCompile(`^//\s*template\s+type\s+(\w+\s*.*?)\s*$`)

func (t *template) findTemplateDefinition(f *ast.File) {
func (t *template) findTemplateDefinition(f *ast.File) bool {
// Inspect the comments
t.templateName = ""
t.templateArgs = nil
Expand All @@ -124,37 +129,22 @@ func (t *template) findTemplateDefinition(f *ast.File) {
}
}
if t.templateName == "" {
fatalf("Didn't find template definition in %s", t.inputFile)
return false
}
if len(t.templateArgs) != len(t.Args) {
fatalf("Wrong number of arguments - template is expecting %d but %d supplied", len(t.Args), len(t.templateArgs))
}
debugf("templateName = %v, templateArgs = %v", t.templateName, t.templateArgs)
}

// Ouput the go formatted file
//
// Exits with a fatal error on error
func outputFile(fset *token.FileSet, f *ast.File, path string) {
fd, err := os.Create(path)
if err != nil {
fatalf("Failed to open %q: %s", path, err)
}
if err := format.Node(fd, fset, f); err != nil {
fatalf("Failed to format %q: %s", path, err)
}
err = fd.Close()
if err != nil {
fatalf("Failed to close %q: %s", path, err)
}
return true
}

// Parses a file into a Fileset and Ast
//
// Dies with a fatal error on error
func parseFile(path string) (*token.FileSet, *ast.File) {
func parseFile(path string, src interface{}) (*token.FileSet, *ast.File) {
fset := token.NewFileSet() // positions are relative to fset
f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
f, err := parser.ParseFile(fset, path, src, parser.ParseComments)
if err != nil {
fatalf("Failed to parse file: %s", err)
}
Expand Down Expand Up @@ -189,14 +179,9 @@ func (t *template) isTemplateArgument(name string) bool {
}

// Parses the template file
func (t *template) parse(inputFile string) {
t.inputFile = inputFile
// Make the name mappings
func (t *template) parse(fset *token.FileSet, f *ast.File) {
t.newIsPublic = ast.IsExported(t.Name)

fset, f := parseFile(inputFile)
t.findTemplateDefinition(f)

// debugf("Decls = %#v", f.Decls)
// Find names which need to be adjusted
namesToMangle := []string{}
Expand Down Expand Up @@ -311,19 +296,48 @@ func (t *template) parse(inputFile string) {
// Change the package to the local package name
f.Name.Name = t.NewPackage

// Output
outputFileName := "gotemplate_" + t.Name + ".go"
outputFile(fset, f, outputFileName)
// Output but only if contents have changed from existing file

b := bytes.NewBuffer(nil)
outputFileName := fmt.Sprintf(*outfile+".go", t.Name)

format := func() {
b.Reset()
if err := format.Node(b, fset, f); err != nil {
fatalf("Failed to format output: %s", err)
}
}

format()

// gofmt one last time to sort out messy identifier substution
fset, f = parseFile(outputFileName)
outputFile(fset, f, outputFileName)
logf("Written '%s'", outputFileName)
// bit gross to inject the header this way... but in the spirit of
// minimal changes et al...
fset, f = parseFile(outputFileName, genHeader+b.String())

format()

write := true

curr, err := ioutil.ReadFile(outputFileName)
if err == nil {
if bytes.Equal(curr, b.Bytes()) {
write = false
}
}

if write {
err := ioutil.WriteFile(outputFileName, b.Bytes(), 0666)
if err != nil {
fatalf("unable to write to %q: %v", outputFileName, err)
}
}

debugf("Written '%s'", outputFileName)
}

// Instantiate the template package
func (t *template) instantiate() {
logf("Substituting %q with %s(%s) into package %s", t.Package, t.Name, strings.Join(t.Args, ","), t.NewPackage)
debugf("Substituting %q with %s(%s) into package %s", t.Package, t.Name, strings.Join(t.Args, ","), t.NewPackage)

p, err := build.Default.Import(t.Package, t.Dir, build.ImportMode(0))
if err != nil {
Expand All @@ -334,14 +348,27 @@ func (t *template) instantiate() {
// FIXME CgoFiles ?
debugf("Go files = %#v", p.GoFiles)

if len(p.GoFiles) == 0 {
fatalf("No go files found for package '%s'", t.Package)
count := 0

var fset *token.FileSet
var file *ast.File

for _, f := range p.GoFiles {
templateFilePath := path.Join(p.Dir, f)
fset, file = parseFile(templateFilePath, nil)
if v := t.findTemplateDefinition(file); v {
t.inputFile = templateFilePath
count++
}
}
// FIXME
if len(p.GoFiles) != 1 {
fatalf("Found more than one go file in '%s' - can only cope with 1 for the moment, sorry", t.Package)

if count == 0 {
fatalf("Failed to find a .go file with a template definition in %v", t.Package)
}

if count > 1 {
fatalf("Found more than template definition %s - can only cope with 1 for the moment, sorry", t.Package)
}

templateFilePath := path.Join(p.Dir, p.GoFiles[0])
t.parse(templateFilePath)
t.parse(fset, file)
}
28 changes: 21 additions & 7 deletions template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ var tests = []TestTemplate{
pkg: "main",
in: basicTest,
outName: "gotemplate_MySet.go",
out: `package main
out: `// Code generated by gotemplate. DO NOT EDIT.

package main

import "fmt"

Expand Down Expand Up @@ -80,7 +82,9 @@ var (
pkg: "main",
in: basicTest,
outName: "gotemplate_mySet.go",
out: `package main
out: `// Code generated by gotemplate. DO NOT EDIT.

package main

import "fmt"

Expand Down Expand Up @@ -117,7 +121,9 @@ func TT(a, b A) A { return Less(a, b) }
func TTone(a A) A { return !Less(a, b) }
`,
outName: "gotemplate_Min.go",
out: `package main
out: `// Code generated by gotemplate. DO NOT EDIT.

package main

// template type TT(A, Less)

Expand Down Expand Up @@ -152,7 +158,9 @@ func (v Vector) Add(b Vector) {
}
`,
outName: "gotemplate_Vector2.go",
out: `package main
out: `// Code generated by gotemplate. DO NOT EDIT.

package main

// template type Vector(A, n)

Expand Down Expand Up @@ -189,7 +197,9 @@ func (mat Matrix) Add(x Matrix) {
}
`,
outName: "gotemplate_Matrix22.go",
out: `package main
out: `// Code generated by gotemplate. DO NOT EDIT.

package main

// template type Matrix(A, n, m)

Expand Down Expand Up @@ -244,7 +254,9 @@ var (
func Prog() int {return a+b+c+d+e+f}
`,
outName: "gotemplate_ProgXX.go",
out: `package main
out: `// Code generated by gotemplate. DO NOT EDIT.

package main

// template type Prog(a, b, c, d, e, f)
type AProgXX float32
Expand Down Expand Up @@ -308,7 +320,9 @@ type (
)
`,
outName: "gotemplate_tmpl.go",
out: `package main
out: `// Code generated by gotemplate. DO NOT EDIT.

package main

// template type TMPL(A, B, C, D, E, F)

Expand Down