Go 语言的 pgx 驱动（其 pgtype 包）内置支持约 70 种 PostgreSQL 类型，但仍有一些类型未被涵盖。本文介绍如何自行编写代码来支持这类特殊类型，并以 PostGIS 的 Geometry 类型为例。
需要实现以下方法：Set、Get、AssignTo（pgtype.Value 接口），DecodeText、DecodeBinary（文本/二进制解码），EncodeText、EncodeBinary（文本/二进制编码），Scan、Value（database/sql 的 Scanner 与 driver.Valuer 接口），以及 MarshalJSON、UnmarshalJSON（JSON 序列化）。
//PostGIS Geometry实现,目前仅支持Point2Dimport ( "bytes""database/sql/driver""encoding/binary""encoding/hex""errors""fmt""github.com/jackc/pgtype""math""strconv""strings") const ( TypeGeoPoint=iota+1TypeGeoPointMTypeGeoPointZ) typeGeometrystruct { Xfloat64Yfloat64Statuspgtype.Status} funcAppendByte(buf []byte, nbyte) []byte { buf=append(buf, n) returnbuf} funcAppendUint16(buf []byte, nuint16) []byte { wp :=len(buf) buf=append(buf, 0, 0) binary.LittleEndian.PutUint16(buf[wp:], n) returnbuf} funcAppendUint32(buf []byte, nuint32) []byte { wp :=len(buf) buf=append(buf, 0, 0, 0, 0) binary.LittleEndian.PutUint32(buf[wp:], n) returnbuf} funcAppendUint64(buf []byte, nuint64) []byte { wp :=len(buf) buf=append(buf, 0, 0, 0, 0, 0, 0, 0, 0) binary.LittleEndian.PutUint64(buf[wp:], n) returnbuf} funcAppendInt16(buf []byte, nint16) []byte { returnAppendUint16(buf, uint16(n)) } funcAppendInt32(buf []byte, nint32) []byte { returnAppendUint32(buf, uint32(n)) } funcAppendInt64(buf []byte, nint64) []byte { returnAppendUint64(buf, uint64(n)) } funcSetInt32(buf []byte, nint32) { binary.LittleEndian.PutUint32(buf, uint32(n)) } func (dst*Geometry) Set(srcinterface{}) error { ifsrc==nil { dst.Status=pgtype.Nullreturnnil } err :=fmt.Errorf("cannot convert %v to Geometry", src) varp*Geometryswitchvalue :=src.(type) { casestring: p, err=parseGeometry([]byte(value)) case []byte: p, err=parseGeometry(value) default: returnerr } iferr!=nil { returnerr } *dst=*preturnnil} funcparseGeometry(src []byte) (*Geometry, error) { ifsrc==nil||bytes.Compare(src, []byte("null")) ==0 { return&Geometry{Status: pgtype.Null}, nil } iflen(src) <5 { returnnil, fmt.Errorf("invalid length for point: %v", len(src)) } ifsrc[0] =='"'&&src[len(src)-1] =='"' { src=src[1 : len(src)-1] } parts :=strings.SplitN(string(src[0:len(src)-1]), ",", 2) iflen(parts) <2 { returnnil, fmt.Errorf("invalid format for point") } x, err :=strconv.ParseFloat(parts[0], 64) iferr!=nil { returnnil, err } y, err :=strconv.ParseFloat(parts[1], 64) 
iferr!=nil { returnnil, err } return&Geometry{X: x, Y: y, Status: pgtype.Present}, nil} func (dstGeometry) Get() interface{} { switchdst.Status { casepgtype.Present: returndstcasepgtype.Null: returnnildefault: returndst.Status } } func (src*Geometry) AssignTo(dstinterface{}) error { returnfmt.Errorf("cannot assign %v to %T", src, dst) } func (dst*Geometry) DecodeText(ci*pgtype.ConnInfo, src []byte) error { ifsrc==nil { *dst=Geometry{Status: pgtype.Null} returnnil } s, _ :=hex.DecodeString(string(src)) l :=len(s) begin :=l-2*8//只取坐标值parts :=s[begin:] iflen(parts) !=16 { returnfmt.Errorf("invalid format for geometry") } x :=binary.LittleEndian.Uint64(parts) y :=binary.LittleEndian.Uint64(parts[8:]) *dst=Geometry{X: math.Float64frombits(x), Y: math.Float64frombits(y), Status: pgtype.Present} //*dst = Geometry{X: x, Y: y, Status: pgtype.Present}returnnil} func (dst*Geometry) DecodeBinary(ci*pgtype.ConnInfo, src []byte) error { ifsrc==nil { *dst=Geometry{Status: pgtype.Null} returnnil } l :=len(src) begin :=l-2*8parts :=src[begin:] iflen(parts) !=16 { returnfmt.Errorf("invalid length for geometry: %v", len(src)) } x :=binary.LittleEndian.Uint64(parts) y :=binary.LittleEndian.Uint64(parts[8:]) *dst=Geometry{ X: math.Float64frombits(x), Y: math.Float64frombits(y), Status: pgtype.Present, } returnnil} func (srcGeometry) EncodeText(ci*pgtype.ConnInfo, buf []byte) ([]byte, error) { switchsrc.Status { casepgtype.Null: returnnil, nilcasepgtype.Undefined: returnnil, errors.New("cannot encode status undefined") } buf=AppendByte(buf, 0x01) //1-不带SRID, 0x20000000 - 带SRID//坐标类型 point(1)buf=AppendInt32(buf, 0x20000000|1) //WGS 84 SRID=4326buf=AppendInt32(buf, 0x10E6) buf=AppendUint64(buf, math.Float64bits(src.X)) buf=AppendUint64(buf, math.Float64bits(src.Y)) s :=hex.EncodeToString(buf) return []byte(s), nil} func (srcGeometry) EncodeBinary(ci*pgtype.ConnInfo, buf []byte) ([]byte, error) { switchsrc.Status { casepgtype.Null: returnnil, nilcasepgtype.Undefined: returnnil, 
errors.New("cannot encode status undefined") } buf=AppendByte(buf, 0x01) //1-不带SRID, 0x20000000 - 带SRID//坐标类型 point(1)buf=AppendInt32(buf, 0x20000000|1) //WGS 84 SRID=4326buf=AppendInt32(buf, 0x10E6) buf=AppendUint64(buf, math.Float64bits(src.X)) buf=AppendUint64(buf, math.Float64bits(src.Y)) returnbuf, nil} // Scan implements the database/sql Scanner interface.func (dst*Geometry) Scan(srcinterface{}) error { ifsrc==nil { *dst=Geometry{Status: pgtype.Null} returnnil } switchsrc :=src.(type) { casestring: returndst.DecodeText(nil, []byte(src)) case []byte: srcCopy :=make([]byte, len(src)) copy(srcCopy, src) returndst.DecodeText(nil, srcCopy) } returnfmt.Errorf("cannot scan %T", src) } // Value implements the database/sql/driver Valuer interface.func (srcGeometry) Value() (driver.Value, error) { returnpgtype.EncodeValueText(src) } func (srcGeometry) MarshalJSON() ([]byte, error) { switchsrc.Status { casepgtype.Present: varbuffbytes.Bufferbuff.WriteByte('"') buff.WriteString(fmt.Sprintf("(%g,%g)", src.X, src.Y)) buff.WriteByte('"') returnbuff.Bytes(), nilcasepgtype.Null: return []byte("null"), nilcasepgtype.Undefined: returnnil, errors.New("cannot encode status undefined") } returnnil, errors.New("invalid status") } func (dst*Geometry) UnmarshalJSON(geometry []byte) error { p, err :=parseGeometry(geometry) iferr!=nil { returnerr } *dst=*preturnnil}
使用前需要注册该自定义类型。由于每个连接都有各自的 ConnInfo，注册需要在连接池的 AfterConnect 回调中完成。请在程序的初始化部分加上以下代码：
urlExample :="postgres://username:password@localhost:5432/database_name"vargeometryOiduint32config, err :=pgxpool.ParseConfig(urlExample) iferr!=nil { log.Panic("parse database config failed", zap.Error(err)) } config.AfterConnect=func(ctxcontext.Context, conn*pgx.Conn) error { //注册PostGIS geometry类型ifgeometryOid==0 { //取得geometry类型的OIDerr=conn.QueryRow(ctx, "select 'geometry'::regtype::oid").Scan(&geometryOid) iferr!=nil { log.Panic("get geometry oid failed", zap.Error(err)) } } ci :=conn.ConnInfo() ci.RegisterDataType(pgtype.DataType{Value: &mypgtype.Geometry{}, Name: "geometry", OID: geometryOid}) returnnil}