You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
boundary/api/internal/genapi/genapi.go

467 lines
11 KiB

package main
// +build genapi
import (
"bufio"
"bytes"
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strings"
"text/template"
"time"
)
// templateType selects which code template createUtilFuncs renders for a
// generateInfo entry.
type templateType int
const (
// templateTypeResource marks a base resource: createUtilFuncs renders it with
// utilFuncsTemplate and createStructs prepends a defaultFields field to it.
templateTypeResource templateType = iota
// templateTypeDetail marks a detail type: createUtilFuncs renders it with
// detailTemplate.
templateTypeDetail
)
// generateInfo describes one struct to generate: where to read the source
// struct from, where to write the result, and how to render it.
type generateInfo struct {
// inFile is the path of the Go source file that is parsed for the struct.
inFile string
// inName is the name of the struct type to look for in inFile.
inName string
// outFile is the path the rendered template is written to.
outFile string
// outName is the name the struct is renamed to in the parsed AST; it is
// also passed to the resource template as the type name.
outName string
// outPkg is the package name emitted in the generated file.
outPkg string
// structFields holds the formatted field lines of the rewritten struct. It
// is filled in by createStructs and consumed by createUtilFuncs.
structFields string
// parentName is passed to the detail template as the embedded parent type.
parentName string
// detailName is passed to the detail template as the generated type name.
detailName string
// templateType selects which template is used to render this entry.
templateType templateType
}
var (
	// basePath is the repository root the generator works against, taken from
	// the APIGEN_BASEPATH environment variable. All input and output paths
	// below are built by appending to it.
	basePath = os.Getenv("APIGEN_BASEPATH")

	// regex extracts the `json:"..."` key/value pair from a struct tag. The
	// value is matched up to its own closing quote (not greedily), so any tag
	// keys that follow the json key are not swallowed into the match.
	regex = regexp.MustCompile(`json:"[^"]*"`)

	// inputStructs lists every struct to generate. createStructs fills in the
	// structFields member of each entry; createUtilFuncs then renders and
	// writes each entry to its outFile.
	inputStructs = []*generateInfo{
		{
			inFile:       basePath + "/internal/gen/controller/api/resources/hosts/host.pb.go",
			inName:       "Host",
			outFile:      basePath + "/api/hosts/host.go",
			outName:      "Host",
			outPkg:       "hosts",
			templateType: templateTypeResource,
		},
		{
			inFile:       basePath + "/internal/gen/controller/api/resources/hosts/host_set.pb.go",
			inName:       "HostSet",
			outFile:      basePath + "/api/hosts/host_set.go",
			outName:      "HostSet",
			outPkg:       "hosts",
			templateType: templateTypeResource,
		},
		{
			inFile:       basePath + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go",
			inName:       "HostCatalog",
			outFile:      basePath + "/api/hosts/host_catalog.go",
			outName:      "HostCatalog",
			outPkg:       "hosts",
			templateType: templateTypeResource,
		},
		{
			inFile:       basePath + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go",
			inName:       "StaticHostCatalogDetails",
			outFile:      basePath + "/api/hosts/static_host_catalog.go",
			outName:      "StaticHostCatalogDetails",
			outPkg:       "hosts",
			parentName:   "HostCatalog",
			detailName:   "StaticHostCatalog",
			templateType: templateTypeDetail,
		},
		{
			inFile:       basePath + "/internal/gen/controller/api/resources/hosts/host_catalog.pb.go",
			inName:       "AwsEc2HostCatalogDetails",
			outFile:      basePath + "/api/hosts/awsec2_host_catalog.go",
			outName:      "AwsEc2HostCatalogDetails",
			outPkg:       "hosts",
			parentName:   "HostCatalog",
			detailName:   "AwsEc2HostCatalog",
			templateType: templateTypeDetail,
		},
	}
)
// visitFn adapts a plain callback so that it can be used as an ast.Visitor.
type visitFn func(node ast.Node)

// Visit runs the callback on node and then hands the same visitor back, so a
// traversal driven by it never prunes a subtree.
func (visit visitFn) Visit(node ast.Node) ast.Visitor {
	visit(node)
	var next ast.Visitor = visit
	return next
}
// main runs the generator's two passes in order: createStructs first, then
// createUtilFuncs, which consumes what the first pass stored on inputStructs.
func main() {
	passes := [...]func(){createStructs, createUtilFuncs}
	for _, pass := range passes {
		pass()
	}
}
func createStructs() {
for _, inputStruct := range inputStructs {
inFile, err := filepath.Abs(inputStruct.inFile)
if err != nil {
fmt.Printf("error opening file %q: %v\n", inputStruct.inFile, err)
os.Exit(1)
}
fset := token.NewFileSet()
inAst, err := parser.ParseFile(fset, inFile, nil, parser.ParseComments)
if err != nil {
fmt.Printf("error parsing %s: %v\n", inFile, err)
os.Exit(1)
}
ast.Walk(visitFn(func(n ast.Node) {
spec, ok := n.(*ast.TypeSpec)
if !ok {
return
}
if spec.Name == nil {
return
}
if spec.Name.Name != inputStruct.inName {
return
}
spec.Name.Name = inputStruct.outName
st, ok := spec.Type.(*ast.StructType)
if !ok {
fmt.Printf("expected struct type for identifier, got %t\n", spec.Type)
os.Exit(1)
return
}
if st.Fields.List == nil {
fmt.Printf("no fields found in %q\n", inputStruct.inName)
os.Exit(1)
return
}
// Tee up unexported field deletion
var elideList []int
{
defer func() {
var cutCount int
// Remove unexported proto stuff
for _, val := range elideList {
loc := val - cutCount
st.Fields.List = append(st.Fields.List[:loc], st.Fields.List[loc+1:]...)
cutCount++
}
// Add default fields if a base resource
if inputStruct.templateType == templateTypeResource {
st.Fields.List = append([]*ast.Field{&ast.Field{
Names: []*ast.Ident{
{
Name: "defaultFields",
Obj: &ast.Object{
Kind: ast.Var,
Name: "defaultFields",
},
},
},
Type: &ast.ArrayType{
Elt: &ast.Ident{
Name: "string",
},
},
}}, st.Fields.List...)
}
}()
}
for i, field := range st.Fields.List {
if !field.Names[0].IsExported() {
elideList = append(elideList, i)
continue
}
// Anything that isn't a basic type we expect to be a star
// expression with a selector; that is, a wrapper value, timestamp
// value, etc.
//
// TODO: this isn't necessarily a good assumption, which means we
// might get failures with other types. This is an internal tools
// only; we can revisit as needed!
var selectorExpr *ast.SelectorExpr
switch typ := field.Type.(type) {
case *ast.Ident:
typ.Name = "*" + typ.Name
goto TAGMODIFY
case *ast.ArrayType:
goto TAGMODIFY
case *ast.StarExpr:
switch nextTyp := typ.X.(type) {
case *ast.Ident:
// Already a pointer, don't do anything
goto TAGMODIFY
case *ast.SelectorExpr:
selectorExpr = nextTyp
}
case *ast.SelectorExpr:
selectorExpr = typ
}
switch {
case selectorExpr != nil:
xident, ok := selectorExpr.X.(*ast.Ident)
if !ok {
fmt.Printf("unexpected non-ident type in selector\n")
os.Exit(1)
}
switch xident.Name {
case "wrappers":
switch selectorExpr.Sel.Name {
case "StringValue":
st.Fields.List[i] = &ast.Field{
Names: field.Names,
Type: &ast.Ident{
Name: "*string",
},
Tag: field.Tag,
}
case "BoolValue":
st.Fields.List[i] = &ast.Field{
Names: field.Names,
Type: &ast.Ident{
Name: "*bool",
},
Tag: field.Tag,
}
case "Int64Value":
st.Fields.List[i] = &ast.Field{
Names: field.Names,
Type: &ast.Ident{
Name: "*int64",
},
Tag: field.Tag,
}
default:
fmt.Printf("unhandled wrappers selector sel name %q\n", selectorExpr.Sel.Name)
os.Exit(1)
}
case "timestamp":
switch selectorExpr.Sel.Name {
case "Timestamp":
st.Fields.List[i] = &ast.Field{
Names: field.Names,
Type: &ast.Ident{
Name: "time.Time",
},
Tag: field.Tag,
}
default:
fmt.Printf("unhandled timestamp selector sel name %q\n", selectorExpr.Sel.Name)
os.Exit(1)
}
case "_struct":
switch selectorExpr.Sel.Name {
case "Struct":
st.Fields.List[i] = &ast.Field{
Names: field.Names,
Type: &ast.MapType{
Key: &ast.Ident{
Name: "string",
},
Value: &ast.InterfaceType{
Methods: &ast.FieldList{},
},
},
Tag: field.Tag,
}
default:
fmt.Printf("unhandled timestamp selector sel name %q\n", selectorExpr.Sel.Name)
os.Exit(1)
}
default:
fmt.Printf("unhandled xident name %q\n", xident.Name)
os.Exit(1)
}
default:
fmt.Println("unhandled non-ident, non-selector case")
os.Exit(1)
}
TAGMODIFY:
st.Fields.List[i].Tag.Value = "`" + regex.FindString(st.Fields.List[i].Tag.Value) + "`"
}
}), inAst)
buf := new(bytes.Buffer)
if err := format.Node(buf, fset, inAst); err != nil {
fmt.Printf("error formatting new code: %v\n", err)
os.Exit(1)
}
// We have to manually cut out the lines we don't want because comments
// won't be preserved otherwise. See the note about lossy comments in
// https://arslan.io/2017/09/14/the-ultimate-guide-to-writing-a-go-tool/
scanner := bufio.NewScanner(bytes.NewBufferString(buf.String()))
var inType bool
var outBuf []string
for scanner.Scan() {
if !inType {
if strings.HasPrefix(scanner.Text(), "type "+inputStruct.outName) {
inType = true
// Don't add this line, we'll do it in the template
continue
}
} else {
if scanner.Text() == "}" {
// We've reached the end of the type
break
}
}
if inType {
outBuf = append(outBuf, scanner.Text())
}
}
inputStruct.structFields = strings.Join(outBuf, "\n")
}
}
// createUtilFuncs renders every inputStructs entry with the template selected
// by its templateType and writes the result to the entry's outFile with mode
// 0644. It relies on structFields having been filled in by createStructs.
//
// Any failure, including a template execution error or an unknown template
// type, is printed and terminates the process with exit status 1 before the
// output file is written.
func createUtilFuncs() {
	for _, inputStruct := range inputStructs {
		outBuf := new(bytes.Buffer)

		var err error
		switch inputStruct.templateType {
		case templateTypeResource:
			err = utilFuncsTemplate.Execute(outBuf, struct {
				Timestamp    time.Time
				Name         string
				Package      string
				StructFields string
			}{
				Timestamp:    time.Now(),
				Name:         inputStruct.outName,
				Package:      inputStruct.outPkg,
				StructFields: inputStruct.structFields,
			})
		case templateTypeDetail:
			err = detailTemplate.Execute(outBuf, struct {
				Timestamp    time.Time
				Package      string
				StructFields string
				ParentName   string
				DetailName   string
			}{
				Timestamp:    time.Now(),
				Package:      inputStruct.outPkg,
				StructFields: inputStruct.structFields,
				ParentName:   inputStruct.parentName,
				DetailName:   inputStruct.detailName,
			})
		default:
			// Refuse to write an empty file for a type no template handles.
			err = fmt.Errorf("unknown template type %d", inputStruct.templateType)
		}
		if err != nil {
			fmt.Printf("error executing template for %q: %v\n", inputStruct.outFile, err)
			os.Exit(1)
		}

		outFile, err := filepath.Abs(inputStruct.outFile)
		if err != nil {
			fmt.Printf("error opening file %q: %v\n", inputStruct.outFile, err)
			os.Exit(1)
		}
		if err := ioutil.WriteFile(outFile, outBuf.Bytes(), 0644); err != nil {
			fmt.Printf("error writing file %q: %v\n", outFile, err)
			os.Exit(1)
		}
	}
}
// utilFuncsTemplate renders the generated file for a base resource. It is
// executed by createUtilFuncs with the fields Timestamp, Name, Package and
// StructFields, and emits the struct declaration plus SetDefault, UnsetDefault
// and MarshalJSON methods for it.
//
// NOTE(review): the emitted import path refers to
// github.com/hashicorp/watchtower — confirm it still matches the module this
// generator is run in.
var utilFuncsTemplate = template.Must(template.New("").Parse(`// Code generated by go generate; DO NOT EDIT.
// This file was generated by robots at
// {{ .Timestamp }}
package {{ .Package }}
import (
"encoding/json"
"github.com/fatih/structs"
"github.com/hashicorp/watchtower/api/internal/strutil"
)
type {{ .Name }} struct {
{{ .StructFields }}
}
func (s *{{ .Name }}) SetDefault(key string) {
s.defaultFields = strutil.AppendIfMissing(s.defaultFields, key)
}
func (s *{{ .Name }}) UnsetDefault(key string) {
s.defaultFields = strutil.StrListDelete(s.defaultFields, key)
}
func (s {{ .Name }}) MarshalJSON() ([]byte, error) {
m := structs.Map(s)
if m == nil {
m = make(map[string]interface{})
}
for _, k := range s.defaultFields {
m[k] = nil
}
return json.Marshal(m)
}
`))
// detailTemplate renders the generated file for a detail type. It is executed
// by createUtilFuncs with the fields Timestamp, Package, StructFields,
// ParentName and DetailName, and emits a struct that embeds a pointer to the
// parent type plus an As<DetailName> method on the parent that decodes the
// parent's Attributes into the detail struct using the "json" tag name.
//
// TODO: Figure out the right way to write out the specific fields
var detailTemplate = template.Must(template.New("").Parse(`// Code generated by go generate; DO NOT EDIT.
// This file was generated by robots at
// {{ .Timestamp }}
package {{ .Package }}
import (
"fmt"
"github.com/mitchellh/mapstructure"
)
type {{ .DetailName }} struct {
*{{ .ParentName }}
{{ .StructFields }}
}
func (s {{ .ParentName }}) As{{ .DetailName }}() (*{{ .DetailName }}, error) {
out := &{{ .DetailName }}{
{{ .ParentName }}: &s,
}
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
Result: out,
TagName: "json",
})
if err != nil {
return nil, fmt.Errorf("error creating map decoder: %w", err)
}
if err := decoder.Decode(s.Attributes); err != nil {
return nil, fmt.Errorf("error decoding attributes map: %w", err)
}
return out, nil
}
`))