package main import ( "go/ast" ) type ( parseVisitor struct { src *sourceContext } typeSpecVisitor struct { src *sourceContext node *ast.TypeSpec iface *iface name *ast.Ident } interfaceTypeVisitor struct { node *ast.TypeSpec ts *typeSpecVisitor methods []method } methodVisitor struct { depth int node *ast.TypeSpec list *[]method name *ast.Ident params, results *[]arg isMethod bool } argListVisitor struct { list *[]arg } argVisitor struct { node *ast.TypeSpec parts []ast.Expr list *[]arg } ) func (v *parseVisitor) Visit(n ast.Node) ast.Visitor { switch rn := n.(type) { default: return v case *ast.File: v.src.pkg = rn.Name return v case *ast.ImportSpec: v.src.imports = append(v.src.imports, rn) return nil case *ast.TypeSpec: switch rn.Type.(type) { default: v.src.types = append(v.src.types, rn) case *ast.InterfaceType: // can't output interfaces // because they'd conflict with our implementations } return &typeSpecVisitor{src: v.src, node: rn} } } /* package foo type FooService interface { Bar(ctx context.Context, i int, s string) (string, error) } */ func (v *typeSpecVisitor) Visit(n ast.Node) ast.Visitor { switch rn := n.(type) { default: return v case *ast.Ident: if v.name == nil { v.name = rn } return v case *ast.InterfaceType: return &interfaceTypeVisitor{ts: v, methods: []method{}} case nil: if v.iface != nil { v.iface.name = v.name sn := *v.name v.iface.stubname = &sn v.iface.stubname.Name = v.name.String() v.src.interfaces = append(v.src.interfaces, *v.iface) } return nil } } func (v *interfaceTypeVisitor) Visit(n ast.Node) ast.Visitor { switch n.(type) { default: return v case *ast.Field: return &methodVisitor{list: &v.methods} case nil: v.ts.iface = &iface{methods: v.methods} return nil } } func (v *methodVisitor) Visit(n ast.Node) ast.Visitor { switch rn := n.(type) { default: v.depth++ return v case *ast.Ident: if rn.IsExported() { v.name = rn } v.depth++ return v case *ast.FuncType: v.depth++ v.isMethod = true return v case *ast.FieldList: if v.params == nil { v.params = &[]arg{} return &argListVisitor{list: v.params} } if v.results == nil { v.results = &[]arg{} } return &argListVisitor{list: v.results} case nil: v.depth-- if v.depth == 0 && v.isMethod && v.name != nil { *v.list = append(*v.list, method{name: v.name, params: *v.params, results: *v.results}) } return nil } } func (v *argListVisitor) Visit(n ast.Node) ast.Visitor { switch n.(type) { default: return nil case *ast.Field: return &argVisitor{list: v.list} } } func (v *argVisitor) Visit(n ast.Node) ast.Visitor { switch t := n.(type) { case *ast.CommentGroup, *ast.BasicLit: return nil case *ast.Ident: //Expr -> everything, but clarity if t.Name != "_" { v.parts = append(v.parts, t) } case ast.Expr: v.parts = append(v.parts, t) case nil: names := v.parts[:len(v.parts)-1] tp := v.parts[len(v.parts)-1] if len(names) == 0 { *v.list = append(*v.list, arg{typ: tp}) return nil } for _, n := range names { *v.list = append(*v.list, arg{ name: n.(*ast.Ident), typ: tp, }) } } return nil }