End to bridge encryption implementation

So far this is passing my basic tests, but could use some testing from people
that are much more familiar with how this is supposed to work.

Refs #27
This commit is contained in:
Gary Kramlich
2022-04-20 06:01:26 -05:00
parent 145c0cc2cb
commit c5f58afe71
22 changed files with 848 additions and 58 deletions

97
database/cryptostore.go Normal file
View File

@@ -0,0 +1,97 @@
package database
import (
"database/sql"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/id"
)
type SQLCryptoStore struct {
*crypto.SQLCryptoStore
UserID id.UserID
GhostIDFormat string
}
var _ crypto.Store = (*SQLCryptoStore)(nil)
func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
return &SQLCryptoStore{
SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "",
[]byte("maunium.net/go/mautrix-whatsapp"),
&cryptoLogger{db.log.Sub("CryptoStore")}),
UserID: userID,
GhostIDFormat: ghostIDFormat,
}
}
func (store *SQLCryptoStore) FindDeviceID() id.DeviceID {
var deviceID id.DeviceID
query := `SELECT device_id FROM crypto_account WHERE account_id=$1`
err := store.DB.QueryRow(query, store.AccountID).Scan(&deviceID)
if err != nil && err != sql.ErrNoRows {
store.Log.Warn("Failed to scan device ID: %v", err)
}
return deviceID
}
func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) ([]id.UserID, error) {
query := `
SELECT user_id FROM mx_user_profile
WHERE room_id=$1
AND (membership='join' OR membership='invite')
AND user_id<>$2
AND user_id NOT LIKE $3
`
members := []id.UserID{}
rows, err := store.DB.Query(query, roomID, store.UserID, store.GhostIDFormat)
if err != nil {
return members, err
}
for rows.Next() {
var userID id.UserID
err := rows.Scan(&userID)
if err != nil {
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
return members, err
}
members = append(members, userID)
}
return members, nil
}
// TODO merge this with the one in the parent package
type cryptoLogger struct {
int log.Logger
}
var levelTrace = log.Level{
Name: "TRACE",
Severity: -10,
Color: -1,
}
func (c *cryptoLogger) Error(message string, args ...interface{}) {
c.int.Errorfln(message, args...)
}
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
c.int.Warnfln(message, args...)
}
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
c.int.Debugfln(message, args...)
}
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
c.int.Logfln(levelTrace, message, args...)
}

View File

@@ -0,0 +1,3 @@
-- This migration is implemented in migrations.go as it comes from
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 0
-- which is described as "Add crypto store to database".

View File

@@ -0,0 +1,3 @@
-- This migration is implemented in migrations.go as it comes from
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 1
-- which is described as "Add account_id to crypto store".

View File

@@ -0,0 +1,3 @@
-- This migration is implemented in migrations.go as it comes from
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 2
-- which is described as "Add megolm withheld data to crypto store".

View File

@@ -0,0 +1,3 @@
-- This migration is implemented in migrations.go as it comes from
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 3
-- which is described as "Add cross-signing keys to crypto store".

View File

@@ -0,0 +1,4 @@
-- This migration is implemented in migrations.go as it comes from
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 4
-- which is described as "Replace VARCHAR(255) with TEXT in the crypto
-- database".

View File

@@ -0,0 +1,4 @@
-- This migration is implemented in migrations.go as it comes from
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 5
-- which is described as "Split last_used into last_encrypted and
-- last_decrypted in crypto store".

View File

@@ -0,0 +1 @@
ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false;

View File

@@ -3,37 +3,18 @@ package migrations
import (
"database/sql"
"embed"
"sort"
"github.com/lopezator/migrator"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
)
//go:embed *.sql
var embeddedMigrations embed.FS
var (
commonMigrations = []string{
"01-initial.sql",
"02-attachments.sql",
"03-emoji.sql",
"04-custom-puppet.sql",
"05-additional-puppet-fields.sql",
"07-guilds.sql",
}
sqliteMigrations = []string{
"06-remove-unique-user-constraint.sqlite.sql",
}
postgresMigrations = []string{
"06-remove-unique-user-constraint.postgres.sql",
}
)
func migrationFromFile(filename string) *migrator.Migration {
func migrationFromFile(description, filename string) *migrator.Migration {
return &migrator.Migration{
Name: filename,
Name: description,
Func: func(tx *sql.Tx) error {
data, err := embeddedMigrations.ReadFile(filename)
if err != nil {
@@ -49,31 +30,83 @@ func migrationFromFile(filename string) *migrator.Migration {
}
}
func migrationFromFileWithDialect(dialect, description, sqliteFile, postgresFile string) *migrator.Migration {
switch dialect {
case "sqlite3":
return migrationFromFile(description, sqliteFile)
case "postgres":
return migrationFromFile(description, postgresFile)
default:
return nil
}
}
func Run(db *sql.DB, baseLog log.Logger, dialect string) error {
subLogger := baseLog.Sub("Migrations")
logger := migrator.LoggerFunc(func(msg string, args ...interface{}) {
subLogger.Infof(msg, args...)
})
migrationNames := commonMigrations
switch dialect {
case "sqlite3":
migrationNames = append(migrationNames, sqliteMigrations...)
case "postgres":
migrationNames = append(migrationNames, postgresMigrations...)
}
sort.Strings(migrationNames)
migrations := make([]interface{}, len(migrationNames))
for idx, name := range migrationNames {
migrations[idx] = migrationFromFile(name)
}
m, err := migrator.New(
migrator.TableName("version"),
migrator.WithLogger(logger),
migrator.Migrations(migrations...),
migrator.Migrations(
migrationFromFile("Initial Schema", "01-initial.sql"),
migrationFromFile("Attachments", "02-attachments.sql"),
migrationFromFile("Emoji", "03-emoji.sql"),
migrationFromFile("Custom Puppets", "04-custom-puppet.sql"),
migrationFromFile(
"Additional puppet fields",
"05-additional-puppet-fields.sql",
),
migrationFromFileWithDialect(
dialect,
"Remove unique user constraint",
"06-remove-unique-user-constraint.sqlite.sql",
"06-remove-unique-user-constraint.postgres.sql",
),
migrationFromFile("Guild Bridging", "07-guilds.sql"),
&migrator.Migration{
Name: "Add crypto store to database",
Func: func(tx *sql.Tx) error {
return sql_store_upgrade.Upgrades[0](tx, dialect)
},
},
&migrator.Migration{
Name: "Add account_id to crypto store",
Func: func(tx *sql.Tx) error {
return sql_store_upgrade.Upgrades[1](tx, dialect)
},
},
&migrator.Migration{
Name: "Add megolm withheld data to crypto store",
Func: func(tx *sql.Tx) error {
return sql_store_upgrade.Upgrades[2](tx, dialect)
},
},
&migrator.Migration{
Name: "Add cross-signing keys to crypto store",
Func: func(tx *sql.Tx) error {
return sql_store_upgrade.Upgrades[3](tx, dialect)
},
},
&migrator.Migration{
Name: "Replace VARCHAR(255) with TEXT in the crypto database",
Func: func(tx *sql.Tx) error {
return sql_store_upgrade.Upgrades[4](tx, dialect)
},
},
&migrator.Migration{
Name: "Split last_used into last_encrypted and last_decrypted in crypto store",
Func: func(tx *sql.Tx) error {
return sql_store_upgrade.Upgrades[5](tx, dialect)
},
},
migrationFromFile(
"Add encryption column to portal table",
"14-add-encrypted-column-to-portal-table.sql",
),
),
)
if err != nil {
return err

View File

@@ -19,6 +19,8 @@ type Portal struct {
Name string
Topic string
Encrypted bool
Avatar string
AvatarURL id.ContentURI
@@ -33,7 +35,8 @@ func (p *Portal) Scan(row Scannable) *Portal {
var typ sql.NullInt32
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name,
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID)
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID,
&p.Encrypted)
if err != nil {
if err != sql.ErrNoRows {
@@ -62,12 +65,12 @@ func (p *Portal) mxidPtr() *id.RoomID {
func (p *Portal) Insert() {
query := "INSERT INTO portal" +
" (channel_id, receiver, mxid, name, topic, avatar, avatar_url," +
" type, dmuser, first_event_id)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
" type, dmuser, first_event_id, encrypted)" +
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(),
p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.DMUser,
p.FirstEventID.String())
p.FirstEventID.String(), p.Encrypted)
if err != nil {
p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
@@ -77,11 +80,12 @@ func (p *Portal) Insert() {
func (p *Portal) Update() {
query := "UPDATE portal SET" +
" mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," +
" dmuser=$7, first_event_id=$8" +
" WHERE channel_id=$9 AND receiver=$10"
" dmuser=$7, first_event_id=$8, encrypted=$9" +
" WHERE channel_id=$10 AND receiver=$11"
_, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar,
p.AvatarURL.String(), p.Type, p.DMUser, p.FirstEventID.String(),
p.Encrypted,
p.Key.ChannelID, p.Key.Receiver)
if err != nil {

View File

@@ -6,6 +6,12 @@ import (
"maunium.net/go/mautrix/id"
)
const (
portalSelect = "SELECT channel_id, receiver, mxid, name, topic, avatar," +
" avatar_url, type, dmuser, first_event_id, encrypted" +
" FROM portal"
)
type PortalQuery struct {
db *Database
log log.Logger
@@ -19,23 +25,23 @@ func (pq *PortalQuery) New() *Portal {
}
func (pq *PortalQuery) GetAll() []*Portal {
return pq.getAll("SELECT * FROM portal")
return pq.getAll(portalSelect)
}
func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
return pq.get("SELECT * FROM portal WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
return pq.get(portalSelect+" WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
}
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
return pq.get(portalSelect+" WHERE mxid=$1", mxid)
}
func (pq *PortalQuery) GetAllByID(id string) []*Portal {
return pq.getAll("SELECT * FROM portal WHERE receiver=$1", id)
return pq.getAll(portalSelect+" WHERE receiver=$1", id)
}
func (pq *PortalQuery) FindPrivateChats(receiver string) []*Portal {
query := "SELECT * FROM portal WHERE receiver=$1 AND type=$2;"
query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
}

View File

@@ -272,3 +272,31 @@ func (s *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventT
return s.GetPowerLevel(roomID, userID) >= s.GetPowerLevelRequirement(roomID, eventType)
}
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) []id.RoomID {
query := "SELECT room_id FROM mx_user_profile" +
"LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id" +
"WHERE user_id=$1 AND portal.encrypted=true"
rooms := []id.RoomID{}
rows, err := store.db.Query(query, userID)
if err != nil {
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
return rooms
}
for rows.Next() {
var roomID id.RoomID
err = rows.Scan(&roomID)
if err != nil {
store.log.Warnfln("Failed to scan room ID: %v", err)
} else {
rooms = append(rooms, roomID)
}
}
return rooms
}