End to bridge encryption implementation
So far this is passing my basic tests, but could use some testing from people that are much more familiar with how this is supposed to work. Refs #27
This commit is contained in:
97
database/cryptostore.go
Normal file
97
database/cryptostore.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type SQLCryptoStore struct {
|
||||
*crypto.SQLCryptoStore
|
||||
UserID id.UserID
|
||||
GhostIDFormat string
|
||||
}
|
||||
|
||||
var _ crypto.Store = (*SQLCryptoStore)(nil)
|
||||
|
||||
func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
|
||||
return &SQLCryptoStore{
|
||||
SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "",
|
||||
[]byte("maunium.net/go/mautrix-whatsapp"),
|
||||
&cryptoLogger{db.log.Sub("CryptoStore")}),
|
||||
UserID: userID,
|
||||
GhostIDFormat: ghostIDFormat,
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) FindDeviceID() id.DeviceID {
|
||||
var deviceID id.DeviceID
|
||||
|
||||
query := `SELECT device_id FROM crypto_account WHERE account_id=$1`
|
||||
err := store.DB.QueryRow(query, store.AccountID).Scan(&deviceID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
store.Log.Warn("Failed to scan device ID: %v", err)
|
||||
}
|
||||
|
||||
return deviceID
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) ([]id.UserID, error) {
|
||||
query := `
|
||||
SELECT user_id FROM mx_user_profile
|
||||
WHERE room_id=$1
|
||||
AND (membership='join' OR membership='invite')
|
||||
AND user_id<>$2
|
||||
AND user_id NOT LIKE $3
|
||||
`
|
||||
|
||||
members := []id.UserID{}
|
||||
|
||||
rows, err := store.DB.Query(query, roomID, store.UserID, store.GhostIDFormat)
|
||||
if err != nil {
|
||||
return members, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var userID id.UserID
|
||||
err := rows.Scan(&userID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
|
||||
return members, err
|
||||
}
|
||||
|
||||
members = append(members, userID)
|
||||
}
|
||||
|
||||
return members, nil
|
||||
}
|
||||
|
||||
// TODO merge this with the one in the parent package
|
||||
type cryptoLogger struct {
|
||||
int log.Logger
|
||||
}
|
||||
|
||||
var levelTrace = log.Level{
|
||||
Name: "TRACE",
|
||||
Severity: -10,
|
||||
Color: -1,
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Error(message string, args ...interface{}) {
|
||||
c.int.Errorfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
|
||||
c.int.Warnfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
|
||||
c.int.Debugfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
|
||||
c.int.Logfln(levelTrace, message, args...)
|
||||
}
|
||||
3
database/migrations/08-add-crypto-store-to-database.sql
Normal file
3
database/migrations/08-add-crypto-store-to-database.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 0
|
||||
-- which is described as "Add crypto store to database".
|
||||
@@ -0,0 +1,3 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 1
|
||||
-- which is described as "Add account_id to crypto store".
|
||||
@@ -0,0 +1,3 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 2
|
||||
-- which is described as "Add megolm withheld data to crypto store".
|
||||
@@ -0,0 +1,3 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 3
|
||||
-- which is described as "Add cross-signing keys to crypto store".
|
||||
@@ -0,0 +1,4 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 4
|
||||
-- which is described as "Replace VARCHAR(255) with TEXT in the crypto
|
||||
-- database".
|
||||
@@ -0,0 +1,4 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 5
|
||||
-- which is described as "Split last_used into last_encrypted and
|
||||
-- last_decrypted in crypto store".
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false;
|
||||
@@ -3,37 +3,18 @@ package migrations
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"sort"
|
||||
|
||||
"github.com/lopezator/migrator"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
)
|
||||
|
||||
//go:embed *.sql
|
||||
var embeddedMigrations embed.FS
|
||||
|
||||
var (
|
||||
commonMigrations = []string{
|
||||
"01-initial.sql",
|
||||
"02-attachments.sql",
|
||||
"03-emoji.sql",
|
||||
"04-custom-puppet.sql",
|
||||
"05-additional-puppet-fields.sql",
|
||||
"07-guilds.sql",
|
||||
}
|
||||
|
||||
sqliteMigrations = []string{
|
||||
"06-remove-unique-user-constraint.sqlite.sql",
|
||||
}
|
||||
|
||||
postgresMigrations = []string{
|
||||
"06-remove-unique-user-constraint.postgres.sql",
|
||||
}
|
||||
)
|
||||
|
||||
func migrationFromFile(filename string) *migrator.Migration {
|
||||
func migrationFromFile(description, filename string) *migrator.Migration {
|
||||
return &migrator.Migration{
|
||||
Name: filename,
|
||||
Name: description,
|
||||
Func: func(tx *sql.Tx) error {
|
||||
data, err := embeddedMigrations.ReadFile(filename)
|
||||
if err != nil {
|
||||
@@ -49,31 +30,83 @@ func migrationFromFile(filename string) *migrator.Migration {
|
||||
}
|
||||
}
|
||||
|
||||
func migrationFromFileWithDialect(dialect, description, sqliteFile, postgresFile string) *migrator.Migration {
|
||||
switch dialect {
|
||||
case "sqlite3":
|
||||
return migrationFromFile(description, sqliteFile)
|
||||
case "postgres":
|
||||
return migrationFromFile(description, postgresFile)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Run(db *sql.DB, baseLog log.Logger, dialect string) error {
|
||||
subLogger := baseLog.Sub("Migrations")
|
||||
logger := migrator.LoggerFunc(func(msg string, args ...interface{}) {
|
||||
subLogger.Infof(msg, args...)
|
||||
})
|
||||
|
||||
migrationNames := commonMigrations
|
||||
switch dialect {
|
||||
case "sqlite3":
|
||||
migrationNames = append(migrationNames, sqliteMigrations...)
|
||||
case "postgres":
|
||||
migrationNames = append(migrationNames, postgresMigrations...)
|
||||
}
|
||||
|
||||
sort.Strings(migrationNames)
|
||||
|
||||
migrations := make([]interface{}, len(migrationNames))
|
||||
for idx, name := range migrationNames {
|
||||
migrations[idx] = migrationFromFile(name)
|
||||
}
|
||||
|
||||
m, err := migrator.New(
|
||||
migrator.TableName("version"),
|
||||
migrator.WithLogger(logger),
|
||||
migrator.Migrations(migrations...),
|
||||
migrator.Migrations(
|
||||
migrationFromFile("Initial Schema", "01-initial.sql"),
|
||||
migrationFromFile("Attachments", "02-attachments.sql"),
|
||||
migrationFromFile("Emoji", "03-emoji.sql"),
|
||||
migrationFromFile("Custom Puppets", "04-custom-puppet.sql"),
|
||||
migrationFromFile(
|
||||
"Additional puppet fields",
|
||||
"05-additional-puppet-fields.sql",
|
||||
),
|
||||
migrationFromFileWithDialect(
|
||||
dialect,
|
||||
"Remove unique user constraint",
|
||||
"06-remove-unique-user-constraint.sqlite.sql",
|
||||
"06-remove-unique-user-constraint.postgres.sql",
|
||||
),
|
||||
migrationFromFile("Guild Bridging", "07-guilds.sql"),
|
||||
&migrator.Migration{
|
||||
Name: "Add crypto store to database",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[0](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Add account_id to crypto store",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[1](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Add megolm withheld data to crypto store",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[2](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Add cross-signing keys to crypto store",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[3](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Replace VARCHAR(255) with TEXT in the crypto database",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[4](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Split last_used into last_encrypted and last_decrypted in crypto store",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[5](tx, dialect)
|
||||
},
|
||||
},
|
||||
migrationFromFile(
|
||||
"Add encryption column to portal table",
|
||||
"14-add-encrypted-column-to-portal-table.sql",
|
||||
),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -19,6 +19,8 @@ type Portal struct {
|
||||
Name string
|
||||
Topic string
|
||||
|
||||
Encrypted bool
|
||||
|
||||
Avatar string
|
||||
AvatarURL id.ContentURI
|
||||
|
||||
@@ -33,7 +35,8 @@ func (p *Portal) Scan(row Scannable) *Portal {
|
||||
var typ sql.NullInt32
|
||||
|
||||
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name,
|
||||
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID)
|
||||
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID,
|
||||
&p.Encrypted)
|
||||
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
@@ -62,12 +65,12 @@ func (p *Portal) mxidPtr() *id.RoomID {
|
||||
func (p *Portal) Insert() {
|
||||
query := "INSERT INTO portal" +
|
||||
" (channel_id, receiver, mxid, name, topic, avatar, avatar_url," +
|
||||
" type, dmuser, first_event_id)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
|
||||
" type, dmuser, first_event_id, encrypted)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
|
||||
|
||||
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(),
|
||||
p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.DMUser,
|
||||
p.FirstEventID.String())
|
||||
p.FirstEventID.String(), p.Encrypted)
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
|
||||
@@ -77,11 +80,12 @@ func (p *Portal) Insert() {
|
||||
func (p *Portal) Update() {
|
||||
query := "UPDATE portal SET" +
|
||||
" mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," +
|
||||
" dmuser=$7, first_event_id=$8" +
|
||||
" WHERE channel_id=$9 AND receiver=$10"
|
||||
" dmuser=$7, first_event_id=$8, encrypted=$9" +
|
||||
" WHERE channel_id=$10 AND receiver=$11"
|
||||
|
||||
_, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar,
|
||||
p.AvatarURL.String(), p.Type, p.DMUser, p.FirstEventID.String(),
|
||||
p.Encrypted,
|
||||
p.Key.ChannelID, p.Key.Receiver)
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,12 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
portalSelect = "SELECT channel_id, receiver, mxid, name, topic, avatar," +
|
||||
" avatar_url, type, dmuser, first_event_id, encrypted" +
|
||||
" FROM portal"
|
||||
)
|
||||
|
||||
type PortalQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
@@ -19,23 +25,23 @@ func (pq *PortalQuery) New() *Portal {
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAll() []*Portal {
|
||||
return pq.getAll("SELECT * FROM portal")
|
||||
return pq.getAll(portalSelect)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
|
||||
return pq.get("SELECT * FROM portal WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
|
||||
return pq.get(portalSelect+" WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
|
||||
return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
|
||||
return pq.get(portalSelect+" WHERE mxid=$1", mxid)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAllByID(id string) []*Portal {
|
||||
return pq.getAll("SELECT * FROM portal WHERE receiver=$1", id)
|
||||
return pq.getAll(portalSelect+" WHERE receiver=$1", id)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) FindPrivateChats(receiver string) []*Portal {
|
||||
query := "SELECT * FROM portal WHERE receiver=$1 AND type=$2;"
|
||||
query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
|
||||
|
||||
return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
|
||||
}
|
||||
|
||||
@@ -272,3 +272,31 @@ func (s *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventT
|
||||
|
||||
return s.GetPowerLevel(roomID, userID) >= s.GetPowerLevelRequirement(roomID, eventType)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) []id.RoomID {
|
||||
query := "SELECT room_id FROM mx_user_profile" +
|
||||
"LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id" +
|
||||
"WHERE user_id=$1 AND portal.encrypted=true"
|
||||
|
||||
rooms := []id.RoomID{}
|
||||
|
||||
rows, err := store.db.Query(query, userID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
|
||||
|
||||
return rooms
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
|
||||
err = rows.Scan(&roomID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to scan room ID: %v", err)
|
||||
} else {
|
||||
rooms = append(rooms, roomID)
|
||||
}
|
||||
}
|
||||
|
||||
return rooms
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user