A ton of work getting towards dms.

This commit also points to my fork of discordgo which makes it look like the
official client which is the only way to get the actually contents of a dm
when not authorized as a bot.
This commit is contained in:
Gary Kramlich
2022-01-25 23:22:20 -06:00
parent 167fdede1f
commit 680f7bdbea
15 changed files with 478 additions and 57 deletions

View File

@@ -5,7 +5,7 @@ import (
) )
func (b *Bridge) updateBotProfile() { func (b *Bridge) updateBotProfile() {
cfg := b.config.Appservice.Bot cfg := b.Config.Appservice.Bot
// Set the bot's avatar. // Set the bot's avatar.
if cfg.Avatar != "" { if cfg.Avatar != "" {

View File

@@ -21,7 +21,7 @@ const (
) )
type Bridge struct { type Bridge struct {
config *config.Config Config *config.Config
log log.Logger log log.Logger
@@ -44,6 +44,8 @@ type Bridge struct {
puppets map[string]*Puppet puppets map[string]*Puppet
puppetsLock sync.Mutex puppetsLock sync.Mutex
StateStore *database.SQLStateStore
} }
func New(cfg *config.Config) (*Bridge, error) { func New(cfg *config.Config) (*Bridge, error) {
@@ -73,12 +75,17 @@ func New(cfg *config.Config) (*Bridge, error) {
return nil, err return nil, err
} }
// Create the state store
logger.Debugln("Initializing state store")
stateStore := database.NewSQLStateStore(db)
appservice.StateStore = stateStore
// Create the bridge. // Create the bridge.
bridge := &Bridge{ bridge := &Bridge{
as: appservice, as: appservice,
db: db, db: db,
bot: bot, bot: bot,
config: cfg, Config: cfg,
log: logger, log: logger,
usersByMXID: make(map[id.UserID]*User), usersByMXID: make(map[id.UserID]*User),
@@ -88,6 +95,10 @@ func New(cfg *config.Config) (*Bridge, error) {
portalsByMXID: make(map[id.RoomID]*Portal), portalsByMXID: make(map[id.RoomID]*Portal),
portalsByID: make(map[database.PortalKey]*Portal), portalsByID: make(map[database.PortalKey]*Portal),
puppets: make(map[string]*Puppet),
StateStore: stateStore,
} }
// Setup the event processors // Setup the event processors

View File

@@ -78,7 +78,7 @@ func (mh *matrixHandler) handleMessage(evt *event.Event) {
content.RemoveReplyFallback() content.RemoveReplyFallback()
if content.MsgType == event.MsgText { if content.MsgType == event.MsgText {
prefix := mh.bridge.config.Bridge.CommandPrefix prefix := mh.bridge.Config.Bridge.CommandPrefix
hasPrefix := strings.HasPrefix(content.Body, prefix) hasPrefix := strings.HasPrefix(content.Body, prefix)
if hasPrefix { if hasPrefix {
@@ -150,16 +150,16 @@ func (mh *matrixHandler) handleBotInvite(evt *event.Event) {
// Wait to send the welcome message until we're sure we're not in an empty // Wait to send the welcome message until we're sure we're not in an empty
// room. // room.
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.config.Bridge.ManagementRoomText.Welcome) mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Welcome)
if evt.RoomID == user.ManagementRoom { if evt.RoomID == user.ManagementRoom {
if user.HasSession() { if user.HasSession() {
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.config.Bridge.ManagementRoomText.Connected) mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.Connected)
} else { } else {
mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.config.Bridge.ManagementRoomText.NotConnected) mh.sendNoticeWithmarkdown(evt.RoomID, mh.bridge.Config.Bridge.ManagementRoomText.NotConnected)
} }
additionalHelp := mh.bridge.config.Bridge.ManagementRoomText.AdditionalHelp additionalHelp := mh.bridge.Config.Bridge.ManagementRoomText.AdditionalHelp
if additionalHelp != "" { if additionalHelp != "" {
mh.sendNoticeWithmarkdown(evt.RoomID, additionalHelp) mh.sendNoticeWithmarkdown(evt.RoomID, additionalHelp)
} }
@@ -201,8 +201,6 @@ func (mh *matrixHandler) handleMembership(evt *event.Event) {
mh.handlePuppetInvite(evt, user, puppet) mh.handlePuppetInvite(evt, user, puppet)
} }
mh.log.Warnln("no existing portal for", evt.RoomID)
return return
} }

View File

@@ -2,9 +2,12 @@ package bridge
import ( import (
"fmt" "fmt"
"sync"
"github.com/bwmarrin/discordgo"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@@ -12,7 +15,12 @@ import (
"gitlab.com/beeper/discord/database" "gitlab.com/beeper/discord/database"
) )
type PortalMatrixMessage struct { type portalDiscordMessage struct {
msg interface{}
user *User
}
type portalMatrixMessage struct {
evt *event.Event evt *event.Event
user *User user *User
} }
@@ -23,9 +31,18 @@ type Portal struct {
bridge *Bridge bridge *Bridge
log log.Logger log log.Logger
matrixMessages chan PortalMatrixMessage channelType discordgo.ChannelType
roomCreateLock sync.Mutex
discordMessages chan portalDiscordMessage
matrixMessages chan portalMatrixMessage
} }
var (
portalCreationDummyEvent = event.Type{Type: "fi.mau.dummy.portal_created", Class: event.MessageEventType}
)
func (b *Bridge) loadPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal { func (b *Bridge) loadPortal(dbPortal *database.Portal, key *database.PortalKey) *Portal {
// If we weren't given a portal we'll attempt to create it if a key was // If we weren't given a portal we'll attempt to create it if a key was
// provided. // provided.
@@ -63,13 +80,26 @@ func (b *Bridge) GetPortalByMXID(mxid id.RoomID) *Portal {
return portal return portal
} }
func (b *Bridge) GetPortalByID(key database.PortalKey) *Portal {
b.portalsLock.Lock()
defer b.portalsLock.Unlock()
portal, ok := b.portalsByID[key]
if !ok {
return b.loadPortal(b.db.Portal.GetByID(key), &key)
}
return portal
}
func (b *Bridge) NewPortal(dbPortal *database.Portal) *Portal { func (b *Bridge) NewPortal(dbPortal *database.Portal) *Portal {
portal := &Portal{ portal := &Portal{
Portal: dbPortal, Portal: dbPortal,
bridge: b, bridge: b,
log: b.log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)), log: b.log.Sub(fmt.Sprintf("Portal/%s", dbPortal.Key)),
matrixMessages: make(chan PortalMatrixMessage, b.config.Bridge.PortalMessageBuffer), discordMessages: make(chan portalDiscordMessage, b.Config.Bridge.PortalMessageBuffer),
matrixMessages: make(chan portalMatrixMessage, b.Config.Bridge.PortalMessageBuffer),
} }
go portal.messageLoop() go portal.messageLoop()
@@ -91,13 +121,15 @@ func (p *Portal) messageLoop() {
for { for {
select { select {
case msg := <-p.matrixMessages: case msg := <-p.matrixMessages:
p.log.Infoln("got message", msg) p.log.Infoln("got matrix message", msg)
case msg := <-p.discordMessages:
p.handleDiscordMessage(msg)
} }
} }
} }
func (p *Portal) IsPrivateChat() bool { func (p *Portal) IsPrivateChat() bool {
return false return (p.channelType == discordgo.ChannelTypeDM || p.channelType == discordgo.ChannelTypeGroupDM)
} }
func (p *Portal) MainIntent() *appservice.IntentAPI { func (p *Portal) MainIntent() *appservice.IntentAPI {
@@ -107,3 +139,142 @@ func (p *Portal) MainIntent() *appservice.IntentAPI {
return p.bridge.bot return p.bridge.bot
} }
func (p *Portal) createMatrixRoom(user *User, channel *discordgo.Channel) error {
p.roomCreateLock.Lock()
defer p.roomCreateLock.Unlock()
// If we have a matrix id the room should exist so we have nothing to do.
if p.MXID != "" {
return nil
}
p.channelType = channel.Type
intent := p.MainIntent()
if err := intent.EnsureRegistered(); err != nil {
return err
}
if p.IsPrivateChat() {
puppet := p.bridge.GetPuppetByID(p.Key.ID)
puppet.SyncContact(user)
p.Name = puppet.DisplayName
p.Avatar = puppet.Avatar
p.AvatarURL = puppet.AvatarURL
}
p.log.Infoln("Creating Matrix room. Info source:", p.Portal.Key.ID)
initialState := []*event.Event{}
creationContent := make(map[string]interface{})
// if !portal.bridge.Config.Bridge.FederateRooms {
creationContent["m.federate"] = false
// }
var invite []id.UserID
if p.IsPrivateChat() {
invite = append(invite, p.bridge.bot.UserID)
}
resp, err := intent.CreateRoom(&mautrix.ReqCreateRoom{
Visibility: "private",
Name: p.Name,
Topic: p.Topic,
Preset: "private_chat",
IsDirect: p.IsPrivateChat(),
InitialState: initialState,
CreationContent: creationContent,
})
if err != nil {
return err
}
p.MXID = resp.RoomID
p.Update()
p.bridge.portalsLock.Lock()
p.bridge.portalsByMXID[p.MXID] = p
p.bridge.portalsLock.Unlock()
p.ensureUserInvited(user)
// if p.IsPrivateChat() {
// puppet := user.bridge.GetPuppetByID(p.Key.ID)
// if p.bridge.Config.Bridge.Encryption.Default {
// err = portal.bridge.Bot.EnsureJoined(portal.MXID)
// if err != nil {
// portal.log.Errorln("Failed to join created portal with bridge bot for e2be:", err)
// }
// }
// user.UpdateDirectChats(map[id.UserID][]id.RoomID{puppet.MXID: {portal.MXID}})
// }
firstEventResp, err := p.MainIntent().SendMessageEvent(p.MXID, portalCreationDummyEvent, struct{}{})
if err != nil {
p.log.Errorln("Failed to send dummy event to mark portal creation:", err)
} else {
p.FirstEventID = firstEventResp.EventID
p.Update()
}
return nil
}
func (p *Portal) handleDiscordMessage(msg portalDiscordMessage) {
if p.MXID == "" {
p.log.Debugln("Creating Matrix room from incoming message")
discordMsg := msg.msg.(*discordgo.MessageCreate)
channel, err := msg.user.Session.Channel(discordMsg.ChannelID)
if err != nil {
p.log.Errorln("Failed to find channel for message:", err)
return
}
if err := p.createMatrixRoom(msg.user, channel); err != nil {
p.log.Errorln("Failed to create portal room:", err)
return
}
}
switch msg.msg.(type) {
case *discordgo.MessageCreate:
p.handleMessage(msg.msg.(*discordgo.MessageCreate).Message)
default:
p.log.Warnln("unknown message type")
}
}
func (p *Portal) ensureUserInvited(user *User) bool {
return user.ensureInvited(p.MainIntent(), p.MXID, p.IsPrivateChat())
}
func (p *Portal) handleMessage(msg *discordgo.Message) {
if p.MXID == "" {
p.log.Warnln("handle message called without a valid portal")
return
}
// TODO: Check if we already got the message
p.log.Debugln("content", msg.Content)
p.log.Debugln("embeds", msg.Embeds)
p.log.Debugln("msg", msg)
content := &event.MessageEventContent{
Body: msg.Content,
MsgType: event.MsgText,
}
resp, err := p.MainIntent().SendMessageEvent(p.MXID, event.EventMessage, content)
p.log.Warnln("response:", resp)
p.log.Warnln("error:", err)
}

View File

@@ -3,6 +3,7 @@ package bridge
import ( import (
"fmt" "fmt"
"regexp" "regexp"
"sync"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
@@ -18,6 +19,8 @@ type Puppet struct {
log log.Logger log log.Logger
MXID id.UserID MXID id.UserID
syncLock sync.Mutex
} }
var userIDRegex *regexp.Regexp var userIDRegex *regexp.Regexp
@@ -36,8 +39,8 @@ func (b *Bridge) ParsePuppetMXID(mxid id.UserID) (string, bool) {
if userIDRegex == nil { if userIDRegex == nil {
pattern := fmt.Sprintf( pattern := fmt.Sprintf(
"^@%s:%s$", "^@%s:%s$",
b.config.Bridge.FormatUsername("([0-9]+)"), b.Config.Bridge.FormatUsername("([0-9]+)"),
b.config.Homeserver.Domain, b.Config.Homeserver.Domain,
) )
userIDRegex = regexp.MustCompile(pattern) userIDRegex = regexp.MustCompile(pattern)
@@ -82,11 +85,27 @@ func (b *Bridge) GetPuppetByID(id string) *Puppet {
func (b *Bridge) FormatPuppetMXID(did string) id.UserID { func (b *Bridge) FormatPuppetMXID(did string) id.UserID {
return id.NewUserID( return id.NewUserID(
b.config.Bridge.FormatUsername(did), b.Config.Bridge.FormatUsername(did),
b.config.Homeserver.Domain, b.Config.Homeserver.Domain,
) )
} }
func (p *Puppet) DefaultIntent() *appservice.IntentAPI { func (p *Puppet) DefaultIntent() *appservice.IntentAPI {
return p.bridge.as.Intent(p.MXID) return p.bridge.as.Intent(p.MXID)
} }
func (p *Puppet) SyncContact(user *User) {
p.syncLock.Lock()
defer p.syncLock.Unlock()
dUser, err := user.Session.User(p.ID)
if err != nil {
p.log.Warnfln("failed to sync puppet %s: %v", p.ID, err)
return
}
p.DisplayName = p.bridge.Config.Bridge.FormatDisplayname(dUser)
p.Update()
}

View File

@@ -1,10 +1,14 @@
package bridge package bridge
import ( import (
"errors"
"strings"
"github.com/bwmarrin/discordgo" "github.com/bwmarrin/discordgo"
"github.com/skip2/go-qrcode" "github.com/skip2/go-qrcode"
log "maunium.net/go/maulogger/v2" log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/appservice" "maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@@ -177,13 +181,131 @@ func (u *User) Login(token string) error {
} }
func (u *User) Connect() error { func (u *User) Connect() error {
u.log.Debugln("connecting to discord")
// get our user info
user, err := u.User.Session.User("@me")
if err != nil {
return err
}
u.User.ID = user.ID
// Add our event handlers
u.User.Session.AddHandler(u.connectedHandler)
u.User.Session.AddHandler(u.disconnectedHandler)
u.User.Session.AddHandler(u.channelCreateHandler)
u.User.Session.AddHandler(u.channelDeleteHandler)
u.User.Session.AddHandler(u.channelPinsUpdateHandler)
u.User.Session.AddHandler(u.channelUpdateHandler)
u.User.Session.AddHandler(u.messageHandler) u.User.Session.AddHandler(u.messageHandler)
u.log.Warnln("logged in, opening websocket") // u.User.Session.Identify.Capabilities = 125
// // Setup our properties
// u.User.Session.Identify.Properties = discordgo.IdentifyProperties{
// OS: "Windows",
// OSVersion: "10",
// Browser: "Chrome",
// BrowserUserAgent: "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/92.0.4515.159 Safari/537.36",
// BrowserVersion: "92.0.4515.159",
// Referrer: "https://discord.com/channels/@me",
// ReferringDomain: "discord.com",
// ClientBuildNumber: "83364",
// ReleaseChannel: "stable",
// }
u.User.Session.Identify.Presence.Status = "online"
return u.User.Session.Open() return u.User.Session.Open()
} }
func (u *User) messageHandler(s *discordgo.Session, m *discordgo.MessageCreate) { func (u *User) connectedHandler(s *discordgo.Session, c *discordgo.Connect) {
u.log.Warnln("received message", m) u.log.Debugln("connected to discord")
}
func (u *User) disconnectedHandler(s *discordgo.Session, d *discordgo.Disconnect) {
u.log.Debugln("disconnected from discord")
}
func (u *User) channelCreateHandler(s *discordgo.Session, c *discordgo.ChannelCreate) {
key := database.NewPortalKey(u.User.ID, c.ID)
portal := u.bridge.GetPortalByID(key)
portal.Name = c.Name
portal.Topic = c.Topic
if c.Icon != "" {
u.log.Debugln("channel icon", c.Icon)
}
portal.Update()
portal.createMatrixRoom(u, c.Channel)
}
func (u *User) channelDeleteHandler(s *discordgo.Session, c *discordgo.ChannelDelete) {
u.log.Debugln("channel delete handler")
}
func (u *User) channelPinsUpdateHandler(s *discordgo.Session, c *discordgo.ChannelPinsUpdate) {
u.log.Debugln("channel pins update")
}
func (u *User) channelUpdateHandler(s *discordgo.Session, c *discordgo.ChannelUpdate) {
key := database.NewPortalKey(u.User.ID, c.ID)
portal := u.bridge.GetPortalByID(key)
portal.Name = c.Name
portal.Topic = c.Topic
u.log.Debugln("channel icon", c.Icon)
portal.Update()
u.log.Debugln("channel update")
}
func (u *User) messageHandler(s *discordgo.Session, m *discordgo.MessageCreate) {
if m.GuildID != "" {
u.log.Debugln("ignoring guild build messaged")
return
}
key := database.NewPortalKey(u.User.ID, m.ChannelID)
portal := u.bridge.GetPortalByID(key)
msg := portalDiscordMessage{
msg: m,
user: u,
}
portal.discordMessages <- msg
}
func (u *User) ensureInvited(intent *appservice.IntentAPI, roomID id.RoomID, isDirect bool) bool {
ret := false
inviteContent := event.Content{
Parsed: &event.MemberEventContent{
Membership: event.MembershipInvite,
IsDirect: isDirect,
},
Raw: map[string]interface{}{},
}
resp, err := intent.SendStateEvent(roomID, event.StateMember, u.MXID.String(), &inviteContent)
u.log.Warnfln("resp: %#v", resp)
var httpErr mautrix.HTTPError
if err != nil && errors.As(err, &httpErr) && httpErr.RespError != nil && strings.Contains(httpErr.RespError.Err, "is already in the room") {
u.bridge.StateStore.SetMembership(roomID, u.MXID, event.MembershipJoin)
ret = true
} else if err != nil {
u.log.Warnfln("Failed to invite user to %s: %v", roomID, err)
} else {
ret = true
}
return ret
} }

View File

@@ -1,12 +1,15 @@
package config package config
import ( import (
"bytes" "strings"
"text/template" "text/template"
"github.com/bwmarrin/discordgo"
) )
type bridge struct { type bridge struct {
UsernameTemplate string `yaml:"username_template"` UsernameTemplate string `yaml:"username_template"`
DisplaynameTemplate string `yaml:"displayname_template"`
CommandPrefix string `yaml:"command_prefix"` CommandPrefix string `yaml:"command_prefix"`
@@ -14,7 +17,8 @@ type bridge struct {
PortalMessageBuffer int `yaml:"portal_message_buffer"` PortalMessageBuffer int `yaml:"portal_message_buffer"`
usernameTemplate *template.Template `yaml:"-"` usernameTemplate *template.Template `yaml:"-"`
displaynameTemplate *template.Template `yaml:"-"`
} }
func (b *bridge) validate() error { func (b *bridge) validate() error {
@@ -24,15 +28,24 @@ func (b *bridge) validate() error {
b.UsernameTemplate = "Discord_{{.}}" b.UsernameTemplate = "Discord_{{.}}"
} }
if b.PortalMessageBuffer <= 0 {
b.PortalMessageBuffer = 128
}
b.usernameTemplate, err = template.New("username").Parse(b.UsernameTemplate) b.usernameTemplate, err = template.New("username").Parse(b.UsernameTemplate)
if err != nil { if err != nil {
return err return err
} }
if b.DisplaynameTemplate == "" {
b.DisplaynameTemplate = "{{.Username}}#{{.Discriminator}} (D){{if .Bot}} (bot){{end}}"
}
b.displaynameTemplate, err = template.New("displayname").Parse(b.DisplaynameTemplate)
if err != nil {
return err
}
if b.PortalMessageBuffer <= 0 {
b.PortalMessageBuffer = 128
}
if b.CommandPrefix == "" { if b.CommandPrefix == "" {
b.CommandPrefix = "!dis" b.CommandPrefix = "!dis"
} }
@@ -60,9 +73,35 @@ func (b *bridge) UnmarshalYAML(unmarshal func(interface{}) error) error {
} }
func (b bridge) FormatUsername(userid string) string { func (b bridge) FormatUsername(userid string) string {
var buffer bytes.Buffer var buffer strings.Builder
b.usernameTemplate.Execute(&buffer, userid) b.usernameTemplate.Execute(&buffer, userid)
return buffer.String() return buffer.String()
} }
type simplfiedUser struct {
Username string
Discriminator string
Locale string
Verified bool
MFAEnabled bool
Bot bool
System bool
}
func (b bridge) FormatDisplayname(user *discordgo.User) string {
var buffer strings.Builder
b.displaynameTemplate.Execute(&buffer, simplfiedUser{
Username: user.Username,
Discriminator: user.Discriminator,
Locale: user.Locale,
Verified: user.Verified,
MFAEnabled: user.MFAEnabled,
Bot: user.Bot,
System: user.System,
})
return buffer.String()
}

View File

@@ -1,20 +1,22 @@
CREATE TABLE IF NOT EXISTS portal ( CREATE TABLE portal (
id TEXT, id TEXT,
receiver TEXT, channel_id TEXT,
mxid TEXT UNIQUE, mxid TEXT UNIQUE,
name TEXT NOT NULL, name TEXT NOT NULL,
topic TEXT NOT NULL, topic TEXT NOT NULL,
avatar TEXT NOT NULL, avatar TEXT NOT NULL,
avatar_url TEXT NOT NULL, avatar_url TEXT,
PRIMARY KEY (id, receiver) first_event_id TEXT,
PRIMARY KEY (id, channel_id)
); );
CREATE TABLE IF NOT EXISTS puppet ( CREATE TABLE puppet (
id TEXT PRIMARY KEY, id TEXT PRIMARY KEY,
displayname TEXT, display_name TEXT,
avatar TEXT, avatar TEXT,
avatar_url TEXT, avatar_url TEXT,
@@ -22,7 +24,7 @@ CREATE TABLE IF NOT EXISTS puppet (
enable_presence BOOLEAN NOT NULL DEFAULT true enable_presence BOOLEAN NOT NULL DEFAULT true
); );
CREATE TABLE IF NOT EXISTS user ( CREATE TABLE user (
mxid TEXT PRIMARY KEY, mxid TEXT PRIMARY KEY,
id TEXT UNIQUE, id TEXT UNIQUE,
@@ -30,3 +32,21 @@ CREATE TABLE IF NOT EXISTS user (
token TEXT token TEXT
); );
CREATE TABLE mx_user_profile (
room_id TEXT,
user_id TEXT,
membership TEXT NOT NULL,
displayname TEXT,
avatar_url TEXT,
PRIMARY KEY (room_id, user_id)
);
CREATE TABLE mx_registrations (
user_id TEXT PRIMARY KEY
);
CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
power_levels TEXT
);

View File

@@ -19,12 +19,14 @@ type Portal struct {
Avatar string Avatar string
AvatarURL id.ContentURI AvatarURL id.ContentURI
FirstEventID id.EventID
} }
func (p *Portal) Scan(row Scannable) *Portal { func (p *Portal) Scan(row Scannable) *Portal {
var mxid, avatarURL sql.NullString var mxid, avatarURL, firstEventID sql.NullString
err := row.Scan(&p.Key.ID, &p.Key.Receiver, &mxid, &p.Name, &p.Topic, &p.Avatar, &avatarURL) err := row.Scan(&p.Key.ID, &p.Key.ChannelID, &mxid, &p.Name, &p.Topic, &p.Avatar, &avatarURL, &firstEventID)
if err != nil { if err != nil {
if err != sql.ErrNoRows { if err != sql.ErrNoRows {
p.log.Errorln("Database scan failed:", err) p.log.Errorln("Database scan failed:", err)
@@ -35,19 +37,34 @@ func (p *Portal) Scan(row Scannable) *Portal {
p.MXID = id.RoomID(mxid.String) p.MXID = id.RoomID(mxid.String)
p.AvatarURL, _ = id.ParseContentURI(avatarURL.String) p.AvatarURL, _ = id.ParseContentURI(avatarURL.String)
p.FirstEventID = id.EventID(firstEventID.String)
return p return p
} }
func (p *Portal) Insert() { func (p *Portal) Insert() {
query := "INSERT INTO portal" + query := "INSERT INTO portal" +
" (id, receiver, mxid, name, topic, avatar, avatar_url)" + " (id, mxid, channel_id, name, topic, avatar, avatar_url, first_event_id)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7)" " VALUES ($1, $2, $3, $4, $5, $6, $7, $8)"
_, err := p.db.Exec(query, p.Key.ID, p.Key.Receiver, p.MXID, _, err := p.db.Exec(query, p.Key.ID, p.MXID, p.Key.ChannelID,
p.Name, p.Topic, p.Avatar, p.AvatarURL.String()) p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.FirstEventID.String())
if err != nil { if err != nil {
p.log.Warnfln("Failed to insert %s: %v", p.Key, err) p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
} }
} }
func (p *Portal) Update() {
query := "UPDATE portal SET" +
" mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, first_event_id=$6" +
" WHERE id=$7 AND channel_id=$8"
_, err := p.db.Exec(query, p.MXID, p.Name, p.Topic, p.Avatar,
p.AvatarURL.String(), p.FirstEventID.String(), p.Key.ID,
p.Key.ChannelID)
if err != nil {
p.log.Warnfln("Failed to update %s: %v", p.Key, err)
}
}

View File

@@ -1,13 +1,20 @@
package database package database
type PortalKey struct { type PortalKey struct {
ID string ID string
Receiver string ChannelID string
}
func NewPortalKey(id, channelID string) PortalKey {
return PortalKey{
ID: id,
ChannelID: channelID,
}
} }
func (key PortalKey) String() string { func (key PortalKey) String() string {
if key.Receiver == key.ID { if key.ChannelID == key.ID {
return key.ID return key.ID
} }
return key.ID + "-" + key.Receiver return key.ID + "-" + key.ChannelID
} }

View File

@@ -21,8 +21,8 @@ func (pq *PortalQuery) GetAll() []*Portal {
return pq.getAll("SELECT * FROM portal") return pq.getAll("SELECT * FROM portal")
} }
func (pq *PortalQuery) GetByDID(key PortalKey) *Portal { func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
return pq.get("SELECT * FROM portal WHERE did=$1 AND receiver=$2", key.ID, key.Receiver) return pq.get("SELECT * FROM portal WHERE id=$1 AND channel_id=$2", key.ID, key.ChannelID)
} }
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal { func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
@@ -30,7 +30,7 @@ func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
} }
func (pq *PortalQuery) GetAllByDID(did string) []*Portal { func (pq *PortalQuery) GetAllByDID(did string) []*Portal {
return pq.getAll("SELECT * FROM portal WHERE did=$1", did) return pq.getAll("SELECT * FROM portal WHERE id=$1", did)
} }
func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal { func (pq *PortalQuery) getAll(query string, args ...interface{}) []*Portal {

View File

@@ -54,3 +54,16 @@ func (p *Puppet) Insert() {
p.log.Warnfln("Failed to insert %s: %v", p.ID, err) p.log.Warnfln("Failed to insert %s: %v", p.ID, err)
} }
} }
func (p *Puppet) Update() {
query := "UPDATE puppet" +
" SET display_name=$1, avatar=$2, avatar_url=$3, enable_presence=$4" +
" WHERE id=$5"
_, err := p.db.Exec(query, p.DisplayName, p.Avatar, p.AvatarURL.String(),
p.EnablePresence, p.ID)
if err != nil {
p.log.Warnfln("Failed to update %s: %v", p.ID, err)
}
}

View File

@@ -19,7 +19,7 @@ func (pq *PuppetQuery) New() *Puppet {
} }
func (pq *PuppetQuery) Get(id string) *Puppet { func (pq *PuppetQuery) Get(id string) *Puppet {
row := pq.db.QueryRow("SELECT id, displayname, avatar, avatar_url, enable_presence FROM puppet WHERE id=$1", id) row := pq.db.QueryRow("SELECT id, display_name, avatar, avatar_url, enable_presence FROM puppet WHERE id=$1", id)
if row == nil { if row == nil {
return nil return nil
} }

2
go.mod
View File

@@ -25,3 +25,5 @@ require (
golang.org/x/net v0.0.0-20211216030914-fe4d6282115f // indirect golang.org/x/net v0.0.0-20211216030914-fe4d6282115f // indirect
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1 // indirect
) )
replace github.com/bwmarrin/discordgo => github.com/grimmy/discordgo v0.23.3-0.20220126043435-7470d1aacd64

2
go.sum
View File

@@ -28,6 +28,8 @@ github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB7
github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ= github.com/gorilla/websocket v1.4.0/go.mod h1:E7qHFY5m1UJ88s3WnNqhKjPHQ0heANvMoAMk2YaljkQ=
github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc= github.com/gorilla/websocket v1.4.2 h1:+/TMaTYc4QFitKJxsQ7Yye35DkWvkdLcvGKqM+x0Ufc=
github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE= github.com/gorilla/websocket v1.4.2/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/grimmy/discordgo v0.23.3-0.20220126043435-7470d1aacd64 h1:KrZb8UPGlmlnUx+eLQDHWnt1uPkdCUeVgUgFRFFdR1o=
github.com/grimmy/discordgo v0.23.3-0.20220126043435-7470d1aacd64/go.mod h1:c1WtWUGN6nREDmzIpyTp/iD3VYt4Fpx+bVyfBG7JE+M=
github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU=
github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= github.com/jessevdk/go-flags v0.0.0-20141203071132-1679536dcc89/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ= github.com/jrick/logrotate v1.0.0/go.mod h1:LNinyqDIJnpAur+b8yyulnQw/wDuN1+BYKlTRt3OuAQ=