《Go语言高级编程》书中的案例，仅作为笔记进行收藏。Protobuf的protoc编译器是通过插件机制实现对不同语言的支持。比如protoc命令出现 --xxx_out 格式的参数，那么protoc将首先查询是否有内置的xxx插件；如果没有内置的xxx插件，那么将继续查询当前系统中是否存在以protoc-gen-xxx命名的可执行程序，最终通过查询到的插件生成代码。对于Go语言的protoc-gen-go插件来说，里面又实现了一层静态插件系统。比如protoc-gen-go内置了一个gRPC插件，用户可以通过 --go_out=plugins=grpc 参数来生成gRPC相关代码，否则只会针对message生成相关代码。
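参考gRPC插件的代码，设计一个netrpcPlugin插件，用于为标准库的RPC框架（net/rpc）生成代码。下文依次给出：初始化模块的命令、netrpc插件包（位于模块下的netrpc目录，导入路径为protoc-gen-go/netrpc）、protoc-gen-go的main函数、构建命令、示例hello.proto、protoc调用命令，以及最终生成的Go代码。注意：构建得到的protoc-gen-go可执行程序必须位于PATH中并优先于官方版本，protoc才会调用它。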
go mod init protoc-gen-go
package netrpc
import (
"bytes"
"log"
"text/template"
"github.com/golang/protobuf/protoc-gen-go/descriptor"
"github.com/golang/protobuf/protoc-gen-go/generator"
)
func init() {
generator.RegisterPlugin(new(netrpcPlugin))
}
// netrpcPlugin is a protoc-gen-go generator plugin that emits Go code for
// the standard library net/rpc framework. Embedding *generator.Generator
// lets the plugin call the generator's helpers (P, TypeName, ObjectNamed)
// directly on p.
type netrpcPlugin struct {
	*generator.Generator
}

// Name reports the plugin name; this is the value users select with
// --go_out=plugins=netrpc.
func (p *netrpcPlugin) Name() string {
	const pluginName = "netrpc"
	return pluginName
}

// Init keeps a reference to the generator that drives this plugin.
func (p *netrpcPlugin) Init(g *generator.Generator) {
	p.Generator = g
}

// GenerateImports adds the net/rpc import, but only for files that declare
// at least one service; message-only files get no extra import.
func (p *netrpcPlugin) GenerateImports(file *generator.FileDescriptor) {
	if len(file.Service) == 0 {
		return
	}
	p.genImportCode(file)
}

// Generate emits the net/rpc glue for every service declared in file.
func (p *netrpcPlugin) Generate(file *generator.FileDescriptor) {
	for i := range file.Service {
		p.genServiceCode(file.Service[i])
	}
}
// The types below form the data model that tmplService is executed against.
// Their exported field names are referenced from the template text, so
// renaming a field requires updating tmplService as well.
type (
	// ServiceSpec describes one protobuf service for code generation.
	ServiceSpec struct {
		ServiceName string              // CamelCased service name
		MethodList  []ServiceMethodSpec // one entry per rpc method, in descriptor order
	}

	// ServiceMethodSpec describes one rpc method of a service.
	ServiceMethodSpec struct {
		MethodName     string // CamelCased method name
		InputTypeName  string // Go type name of the request message
		OutputTypeName string // Go type name of the response message
	}
)
// genImportCode emits the import of the standard library net/rpc package,
// which the generated service code uses for rpc.Server, rpc.Client and
// rpc.Dial. The file descriptor is not consulted: the import line is the
// same for every file.
func (p *netrpcPlugin) genImportCode(_ *generator.FileDescriptor) {
	const netRPCImport = `import "net/rpc"`
	p.P(netRPCImport)
}
func (p *netrpcPlugin) genServiceCode(svc *descriptor.ServiceDescriptorProto) {
spec := p.buildServiceSpec(svc)
var buf bytes.Buffer
t := template.Must(template.New("").Parse(tmplService))
err := t.Execute(&buf, spec)
if err != nil {
log.Fatal(err)
}
p.P(buf.String())
}
// buildServiceSpec converts a service descriptor into the *ServiceSpec that
// tmplService is executed against. Service and method names are CamelCased
// with generator.CamelCase; request and response types are resolved to Go
// type names through the generator's ObjectNamed and TypeName lookups.
// A service with no methods yields a nil MethodList.
func (p *netrpcPlugin) buildServiceSpec(svc *descriptor.ServiceDescriptorProto) *ServiceSpec {
	serviceName := generator.CamelCase(svc.GetName())

	var methods []ServiceMethodSpec
	for i := range svc.Method {
		method := svc.Method[i]
		name := generator.CamelCase(method.GetName())
		inType := p.TypeName(p.ObjectNamed(method.GetInputType()))
		outType := p.TypeName(p.ObjectNamed(method.GetOutputType()))
		methods = append(methods, ServiceMethodSpec{
			MethodName:     name,
			InputTypeName:  inType,
			OutputTypeName: outType,
		})
	}

	return &ServiceSpec{
		ServiceName: serviceName,
		MethodList:  methods,
	}
}
// tmplService is the text/template that genServiceCode executes against a
// *ServiceSpec to emit the net/rpc glue for one service. It reads:
//   - .ServiceName: the CamelCased service name
//   - .MethodList[i].MethodName, .InputTypeName, .OutputTypeName
//
// For a service <S> it generates:
//   - <S>Interface: the method set a server implementation must provide;
//   - Register<S>: registers an implementation on an *rpc.Server under the
//     name <S>;
//   - <S>Client and Dial<S>: a typed client wrapping *rpc.Client whose
//     methods call "<S>.<Method>".
//
// Every action must keep its {{ and }} delimiters intact. A delimiter split
// across lines is treated as literal text and produces uncompilable output.
const tmplService = `
{{$root := .}}

type {{.ServiceName}}Interface interface {
	{{- range $_, $m := .MethodList}}
	{{$m.MethodName}}(in *{{$m.InputTypeName}}, out *{{$m.OutputTypeName}}) error
	{{- end}}
}

func Register{{.ServiceName}}(srv *rpc.Server, x {{.ServiceName}}Interface) error {
	if err := srv.RegisterName("{{.ServiceName}}", x); err != nil {
		return err
	}
	return nil
}

type {{.ServiceName}}Client struct {
	*rpc.Client
}

var _ {{.ServiceName}}Interface = (*{{.ServiceName}}Client)(nil)

func Dial{{.ServiceName}}(network, address string) (*{{.ServiceName}}Client, error) {
	c, err := rpc.Dial(network, address)
	if err != nil {
		return nil, err
	}
	return &{{.ServiceName}}Client{Client: c}, nil
}
{{range $_, $m := .MethodList}}
func (p *{{$root.ServiceName}}Client) {{$m.MethodName}}(in *{{$m.InputTypeName}}, out *{{$m.OutputTypeName}}) error {
	return p.Client.Call("{{$root.ServiceName}}.{{$m.MethodName}}", in, out)
}
{{end}}
`
package main
import (
"io/ioutil"
"os"
"github.com/golang/protobuf/proto"
"github.com/golang/protobuf/protoc-gen-go/generator"
_ "protoc-gen-go/netrpc"
)
// main is the protoc plugin entry point. protoc sends the serialized
// generation request on stdin; the serialized response (gen.Response) is
// written back to stdout. Plugins such as netrpc are linked in through
// blank imports and enabled with --go_out=plugins=<name>.
func main() {
	gen := generator.New()

	// Read and decode the request protoc sent on stdin.
	input, err := ioutil.ReadAll(os.Stdin)
	if err != nil {
		gen.Error(err, "reading input")
	}
	if err = proto.Unmarshal(input, gen.Request); err != nil {
		gen.Error(err, "parsing input proto")
	}
	if len(gen.Request.FileToGenerate) == 0 {
		gen.Fail("no files to generate")
	}

	// Apply the --go_out parameters (e.g. plugins=netrpc), then run the
	// generator's preparation passes before emitting every requested file.
	gen.CommandLineParameters(gen.Request.GetParameter())
	gen.WrapTypes()
	gen.SetPackageNames()
	gen.BuildTypeNameMap()
	gen.GenerateAllFiles()

	// Encode the response and hand it back to protoc on stdout.
	output, err := proto.Marshal(gen.Response)
	if err != nil {
		gen.Error(err, "failed to marshal output proto")
	}
	if _, err = os.Stdout.Write(output); err != nil {
		gen.Error(err, "failed to write output proto")
	}
}
go build .
syntax = "proto3";
package hello;
// String wraps a single string value. HelloService uses it as both the
// request and the response type of Hello.
message String {
  string value = 1;
}
// HelloService exposes a single RPC, Hello, which takes a String and
// returns a String. Compiled with --go_out=plugins=netrpc, it yields the
// HelloServiceInterface, RegisterHelloService and HelloServiceClient code.
service HelloService {
  rpc Hello(String) returns (String);
}
protoc --go_out=plugins=netrpc:. hello.proto
// Code generated by protoc-gen-go. DO NOT EDIT.
// source: hello.proto
package hello
import (
fmt "fmt"
proto "github.com/golang/protobuf/proto"
math "math"
)
import "net/rpc"
// NOTE(review): this file is protoc-gen-go output for hello.proto (see the
// "DO NOT EDIT" header); change hello.proto and regenerate instead of editing.
// Reference imports to suppress errors if they are not otherwise used.
var _ = proto.Marshal
var _ = fmt.Errorf
var _ = math.Inf
// This is a compile-time assertion to ensure that this generated file
// is compatible with the proto package it is being compiled against.
// A compilation error at this line likely means your copy of the
// proto package needs to be updated.
const _ = proto.ProtoPackageIsVersion3 // please upgrade the proto package
// String is the Go type generated for message hello.String in hello.proto.
// Value carries field 1 ("value"); the XXX_-prefixed fields are generator
// bookkeeping.
// NOTE(review): the struct tags and XXX_ fields are presumably read by the
// proto runtime via reflection — do not reorder or remove them by hand.
type String struct {
Value string `protobuf:"bytes,1,opt,name=value,proto3" json:"value,omitempty"`
XXX_NoUnkeyedLiteral struct{} `json:"-"`
XXX_unrecognized []byte `json:"-"`
XXX_sizecache int32 `json:"-"`
}
// Reset sets m back to the zero String.
func (m *String) Reset() { *m = String{} }
// String formats m in compact text form via proto.CompactTextString.
func (m *String) String() string { return proto.CompactTextString(m) }
// ProtoMessage is an empty marker method (presumably so *String satisfies
// proto.Message).
func (*String) ProtoMessage() {}
// Descriptor returns the gzipped file descriptor of hello.proto together with
// the index path of this message within it ([]int{0}: the first message).
func (*String) Descriptor() ([]byte, []int) {
return fileDescriptor_61ef911816e0a8ce, []int{0}
}
// The XXX_ methods below delegate unmarshal/marshal/merge/size/discard to the
// shared xxx_messageInfo_String; they are presumably internal hooks used by
// the proto runtime rather than part of the public API.
func (m *String) XXX_Unmarshal(b []byte) error {
return xxx_messageInfo_String.Unmarshal(m, b)
}
func (m *String) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) {
return xxx_messageInfo_String.Marshal(b, m, deterministic)
}
func (m *String) XXX_Merge(src proto.Message) {
xxx_messageInfo_String.Merge(m, src)
}
func (m *String) XXX_Size() int {
return xxx_messageInfo_String.Size(m)
}
func (m *String) XXX_DiscardUnknown() {
xxx_messageInfo_String.DiscardUnknown(m)
}
// xxx_messageInfo_String is the per-type state the XXX_ methods delegate to.
var xxx_messageInfo_String proto.InternalMessageInfo
// GetValue returns m.Value, or "" when m is nil, so it is safe to call on a
// nil *String.
func (m *String) GetValue() string {
if m != nil {
return m.Value
}
return ""
}
// init registers the String type under its full proto name "hello.String".
func init() {
proto.RegisterType((*String)(nil), "hello.String")
}
// init registers the compressed descriptor of hello.proto under "hello.proto".
func init() { proto.RegisterFile("hello.proto", fileDescriptor_61ef911816e0a8ce) }
// fileDescriptor_61ef911816e0a8ce holds the gzipped FileDescriptorProto of
// hello.proto, as emitted by the generator; it must stay byte-identical.
var fileDescriptor_61ef911816e0a8ce = []byte{
// 103 bytes of a gzipped FileDescriptorProto
0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0xe2, 0xce, 0x48, 0xcd, 0xc9,
0xc9, 0xd7, 0x2b, 0x28, 0xca, 0x2f, 0xc9, 0x17, 0x62, 0x05, 0x73, 0x94, 0xe4, 0xb8, 0xd8, 0x82,
0x4b, 0x8a, 0x32, 0xf3, 0xd2, 0x85, 0x44, 0xb8, 0x58, 0xcb, 0x12, 0x73, 0x4a, 0x53, 0x25, 0x18,
0x15, 0x18, 0x35, 0x38, 0x83, 0x20, 0x1c, 0x23, 0x53, 0x2e, 0x1e, 0x0f, 0x90, 0xc2, 0xe0, 0xd4,
0xa2, 0xb2, 0xcc, 0xe4, 0x54, 0x21, 0x55, 0x2e, 0x56, 0x30, 0x5f, 0x88, 0x57, 0x0f, 0x62, 0x1a,
0x44, 0xb7, 0x14, 0x2a, 0x37, 0x89, 0x0d, 0x6c, 0x89, 0x31, 0x20, 0x00, 0x00, 0xff, 0xff, 0x15,
0xe8, 0xb1, 0xcc, 0x73, 0x00, 0x00, 0x00,
}
// HelloServiceInterface is the method set a net/rpc server implementation of
// HelloService must provide. It is emitted by the netrpc plugin's service
// template.
type HelloServiceInterface interface {
Hello(in *String, out *String) error
}
// RegisterHelloService registers x on srv under the name "HelloService", so
// clients address its methods as "HelloService.<Method>". It returns the
// error from srv.RegisterName unchanged.
func RegisterHelloService(srv *rpc.Server, x HelloServiceInterface) error {
if err := srv.RegisterName("HelloService", x); err != nil {
return err
}
return nil
}
// HelloServiceClient is a typed HelloService client that embeds *rpc.Client.
type HelloServiceClient struct {
*rpc.Client
}
// Compile-time check that *HelloServiceClient implements HelloServiceInterface.
var _ HelloServiceInterface = (*HelloServiceClient)(nil)
// DialHelloService connects with rpc.Dial on the given network and address and
// returns a HelloServiceClient wrapping the connection, or the dial error.
func DialHelloService(network, address string) (*HelloServiceClient, error) {
c, err := rpc.Dial(network, address)
if err != nil {
return nil, err
}
return &HelloServiceClient{Client: c}, nil
}
// Hello invokes the remote method "HelloService.Hello", passing in as the
// argument and out as the reply.
func (p *HelloServiceClient) Hello(in *String, out *String) error {
return p.Client.Call("HelloService.Hello", in, out)
}