Merge remote-tracking branch 'beeper/main'
This commit is contained in:
97
database/cryptostore.go
Normal file
97
database/cryptostore.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type SQLCryptoStore struct {
|
||||
*crypto.SQLCryptoStore
|
||||
UserID id.UserID
|
||||
GhostIDFormat string
|
||||
}
|
||||
|
||||
var _ crypto.Store = (*SQLCryptoStore)(nil)
|
||||
|
||||
func NewSQLCryptoStore(db *Database, userID id.UserID, ghostIDFormat string) *SQLCryptoStore {
|
||||
return &SQLCryptoStore{
|
||||
SQLCryptoStore: crypto.NewSQLCryptoStore(db.DB, db.dialect, "", "",
|
||||
[]byte("maunium.net/go/mautrix-whatsapp"),
|
||||
&cryptoLogger{db.log.Sub("CryptoStore")}),
|
||||
UserID: userID,
|
||||
GhostIDFormat: ghostIDFormat,
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) FindDeviceID() id.DeviceID {
|
||||
var deviceID id.DeviceID
|
||||
|
||||
query := `SELECT device_id FROM crypto_account WHERE account_id=$1`
|
||||
err := store.DB.QueryRow(query, store.AccountID).Scan(&deviceID)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
store.Log.Warn("Failed to scan device ID: %v", err)
|
||||
}
|
||||
|
||||
return deviceID
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetRoomMembers(roomID id.RoomID) ([]id.UserID, error) {
|
||||
query := `
|
||||
SELECT user_id FROM mx_user_profile
|
||||
WHERE room_id=$1
|
||||
AND (membership='join' OR membership='invite')
|
||||
AND user_id<>$2
|
||||
AND user_id NOT LIKE $3
|
||||
`
|
||||
|
||||
members := []id.UserID{}
|
||||
|
||||
rows, err := store.DB.Query(query, roomID, store.UserID, store.GhostIDFormat)
|
||||
if err != nil {
|
||||
return members, err
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var userID id.UserID
|
||||
err := rows.Scan(&userID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
|
||||
return members, err
|
||||
}
|
||||
|
||||
members = append(members, userID)
|
||||
}
|
||||
|
||||
return members, nil
|
||||
}
|
||||
|
||||
// TODO merge this with the one in the parent package
|
||||
type cryptoLogger struct {
|
||||
int log.Logger
|
||||
}
|
||||
|
||||
var levelTrace = log.Level{
|
||||
Name: "TRACE",
|
||||
Severity: -10,
|
||||
Color: -1,
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Error(message string, args ...interface{}) {
|
||||
c.int.Errorfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Warn(message string, args ...interface{}) {
|
||||
c.int.Warnfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Debug(message string, args ...interface{}) {
|
||||
c.int.Debugfln(message, args...)
|
||||
}
|
||||
|
||||
func (c *cryptoLogger) Trace(message string, args ...interface{}) {
|
||||
c.int.Logfln(levelTrace, message, args...)
|
||||
}
|
||||
3
database/migrations/08-add-crypto-store-to-database.sql
Normal file
3
database/migrations/08-add-crypto-store-to-database.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 0
|
||||
-- which is described as "Add crypto store to database".
|
||||
@@ -0,0 +1,3 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 1
|
||||
-- which is described as "Add account_id to crypto store".
|
||||
@@ -0,0 +1,3 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 2
|
||||
-- which is described as "Add megolm withheld data to crypto store".
|
||||
@@ -0,0 +1,3 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 3
|
||||
-- which is described as "Add cross-signing keys to crypto store".
|
||||
@@ -0,0 +1,4 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 4
|
||||
-- which is described as "Replace VARCHAR(255) with TEXT in the crypto
|
||||
-- database".
|
||||
@@ -0,0 +1,4 @@
|
||||
-- This migration is implemented in migrations.go as it comes from
|
||||
-- maunium.net/go/mautrix/crypto/sql_store_upgrade. It runs upgrade at index 5
|
||||
-- which is described as "Split last_used into last_encrypted and
|
||||
-- last_decrypted in crypto store".
|
||||
@@ -0,0 +1 @@
|
||||
ALTER TABLE portal ADD COLUMN encrypted BOOLEAN NOT NULL DEFAULT false;
|
||||
@@ -3,37 +3,18 @@ package migrations
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"sort"
|
||||
|
||||
"github.com/lopezator/migrator"
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
)
|
||||
|
||||
//go:embed *.sql
|
||||
var embeddedMigrations embed.FS
|
||||
|
||||
var (
|
||||
commonMigrations = []string{
|
||||
"01-initial.sql",
|
||||
"02-attachments.sql",
|
||||
"03-emoji.sql",
|
||||
"04-custom-puppet.sql",
|
||||
"05-additional-puppet-fields.sql",
|
||||
"07-guilds.sql",
|
||||
}
|
||||
|
||||
sqliteMigrations = []string{
|
||||
"06-remove-unique-user-constraint.sqlite.sql",
|
||||
}
|
||||
|
||||
postgresMigrations = []string{
|
||||
"06-remove-unique-user-constraint.postgres.sql",
|
||||
}
|
||||
)
|
||||
|
||||
func migrationFromFile(filename string) *migrator.Migration {
|
||||
func migrationFromFile(description, filename string) *migrator.Migration {
|
||||
return &migrator.Migration{
|
||||
Name: filename,
|
||||
Name: description,
|
||||
Func: func(tx *sql.Tx) error {
|
||||
data, err := embeddedMigrations.ReadFile(filename)
|
||||
if err != nil {
|
||||
@@ -49,31 +30,83 @@ func migrationFromFile(filename string) *migrator.Migration {
|
||||
}
|
||||
}
|
||||
|
||||
func migrationFromFileWithDialect(dialect, description, sqliteFile, postgresFile string) *migrator.Migration {
|
||||
switch dialect {
|
||||
case "sqlite3":
|
||||
return migrationFromFile(description, sqliteFile)
|
||||
case "postgres":
|
||||
return migrationFromFile(description, postgresFile)
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Run(db *sql.DB, baseLog log.Logger, dialect string) error {
|
||||
subLogger := baseLog.Sub("Migrations")
|
||||
logger := migrator.LoggerFunc(func(msg string, args ...interface{}) {
|
||||
subLogger.Infof(msg, args...)
|
||||
})
|
||||
|
||||
migrationNames := commonMigrations
|
||||
switch dialect {
|
||||
case "sqlite3":
|
||||
migrationNames = append(migrationNames, sqliteMigrations...)
|
||||
case "postgres":
|
||||
migrationNames = append(migrationNames, postgresMigrations...)
|
||||
}
|
||||
|
||||
sort.Strings(migrationNames)
|
||||
|
||||
migrations := make([]interface{}, len(migrationNames))
|
||||
for idx, name := range migrationNames {
|
||||
migrations[idx] = migrationFromFile(name)
|
||||
}
|
||||
|
||||
m, err := migrator.New(
|
||||
migrator.TableName("version"),
|
||||
migrator.WithLogger(logger),
|
||||
migrator.Migrations(migrations...),
|
||||
migrator.Migrations(
|
||||
migrationFromFile("Initial Schema", "01-initial.sql"),
|
||||
migrationFromFile("Attachments", "02-attachments.sql"),
|
||||
migrationFromFile("Emoji", "03-emoji.sql"),
|
||||
migrationFromFile("Custom Puppets", "04-custom-puppet.sql"),
|
||||
migrationFromFile(
|
||||
"Additional puppet fields",
|
||||
"05-additional-puppet-fields.sql",
|
||||
),
|
||||
migrationFromFileWithDialect(
|
||||
dialect,
|
||||
"Remove unique user constraint",
|
||||
"06-remove-unique-user-constraint.sqlite.sql",
|
||||
"06-remove-unique-user-constraint.postgres.sql",
|
||||
),
|
||||
migrationFromFile("Guild Bridging", "07-guilds.sql"),
|
||||
&migrator.Migration{
|
||||
Name: "Add crypto store to database",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[0](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Add account_id to crypto store",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[1](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Add megolm withheld data to crypto store",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[2](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Add cross-signing keys to crypto store",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[3](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Replace VARCHAR(255) with TEXT in the crypto database",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[4](tx, dialect)
|
||||
},
|
||||
},
|
||||
&migrator.Migration{
|
||||
Name: "Split last_used into last_encrypted and last_decrypted in crypto store",
|
||||
Func: func(tx *sql.Tx) error {
|
||||
return sql_store_upgrade.Upgrades[5](tx, dialect)
|
||||
},
|
||||
},
|
||||
migrationFromFile(
|
||||
"Add encryption column to portal table",
|
||||
"14-add-encrypted-column-to-portal-table.sql",
|
||||
),
|
||||
),
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@@ -19,6 +19,8 @@ type Portal struct {
|
||||
Name string
|
||||
Topic string
|
||||
|
||||
Encrypted bool
|
||||
|
||||
Avatar string
|
||||
AvatarURL id.ContentURI
|
||||
|
||||
@@ -33,7 +35,8 @@ func (p *Portal) Scan(row Scannable) *Portal {
|
||||
var typ sql.NullInt32
|
||||
|
||||
err := row.Scan(&p.Key.ChannelID, &p.Key.Receiver, &mxid, &p.Name,
|
||||
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID)
|
||||
&p.Topic, &p.Avatar, &avatarURL, &typ, &p.DMUser, &firstEventID,
|
||||
&p.Encrypted)
|
||||
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
@@ -62,12 +65,12 @@ func (p *Portal) mxidPtr() *id.RoomID {
|
||||
func (p *Portal) Insert() {
|
||||
query := "INSERT INTO portal" +
|
||||
" (channel_id, receiver, mxid, name, topic, avatar, avatar_url," +
|
||||
" type, dmuser, first_event_id)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)"
|
||||
" type, dmuser, first_event_id, encrypted)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)"
|
||||
|
||||
_, err := p.db.Exec(query, p.Key.ChannelID, p.Key.Receiver, p.mxidPtr(),
|
||||
p.Name, p.Topic, p.Avatar, p.AvatarURL.String(), p.Type, p.DMUser,
|
||||
p.FirstEventID.String())
|
||||
p.FirstEventID.String(), p.Encrypted)
|
||||
|
||||
if err != nil {
|
||||
p.log.Warnfln("Failed to insert %s: %v", p.Key, err)
|
||||
@@ -77,11 +80,12 @@ func (p *Portal) Insert() {
|
||||
func (p *Portal) Update() {
|
||||
query := "UPDATE portal SET" +
|
||||
" mxid=$1, name=$2, topic=$3, avatar=$4, avatar_url=$5, type=$6," +
|
||||
" dmuser=$7, first_event_id=$8" +
|
||||
" WHERE channel_id=$9 AND receiver=$10"
|
||||
" dmuser=$7, first_event_id=$8, encrypted=$9" +
|
||||
" WHERE channel_id=$10 AND receiver=$11"
|
||||
|
||||
_, err := p.db.Exec(query, p.mxidPtr(), p.Name, p.Topic, p.Avatar,
|
||||
p.AvatarURL.String(), p.Type, p.DMUser, p.FirstEventID.String(),
|
||||
p.Encrypted,
|
||||
p.Key.ChannelID, p.Key.Receiver)
|
||||
|
||||
if err != nil {
|
||||
|
||||
@@ -6,6 +6,12 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
portalSelect = "SELECT channel_id, receiver, mxid, name, topic, avatar," +
|
||||
" avatar_url, type, dmuser, first_event_id, encrypted" +
|
||||
" FROM portal"
|
||||
)
|
||||
|
||||
type PortalQuery struct {
|
||||
db *Database
|
||||
log log.Logger
|
||||
@@ -19,23 +25,23 @@ func (pq *PortalQuery) New() *Portal {
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAll() []*Portal {
|
||||
return pq.getAll("SELECT * FROM portal")
|
||||
return pq.getAll(portalSelect)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByID(key PortalKey) *Portal {
|
||||
return pq.get("SELECT * FROM portal WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
|
||||
return pq.get(portalSelect+" WHERE channel_id=$1 AND receiver=$2", key.ChannelID, key.Receiver)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetByMXID(mxid id.RoomID) *Portal {
|
||||
return pq.get("SELECT * FROM portal WHERE mxid=$1", mxid)
|
||||
return pq.get(portalSelect+" WHERE mxid=$1", mxid)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) GetAllByID(id string) []*Portal {
|
||||
return pq.getAll("SELECT * FROM portal WHERE receiver=$1", id)
|
||||
return pq.getAll(portalSelect+" WHERE receiver=$1", id)
|
||||
}
|
||||
|
||||
func (pq *PortalQuery) FindPrivateChats(receiver string) []*Portal {
|
||||
query := "SELECT * FROM portal WHERE receiver=$1 AND type=$2;"
|
||||
query := portalSelect + " portal WHERE receiver=$1 AND type=$2;"
|
||||
|
||||
return pq.getAll(query, receiver, discordgo.ChannelTypeDM)
|
||||
}
|
||||
|
||||
@@ -272,3 +272,33 @@ func (s *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventT
|
||||
|
||||
return s.GetPowerLevel(roomID, userID) >= s.GetPowerLevelRequirement(roomID, eventType)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) []id.RoomID {
|
||||
query := `
|
||||
SELECT room_id FROM mx_user_profile
|
||||
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
|
||||
WHERE user_id=$1 AND portal.encrypted=true
|
||||
`
|
||||
|
||||
rooms := []id.RoomID{}
|
||||
|
||||
rows, err := store.db.Query(query, userID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to query shared rooms with %s: %v", userID, err)
|
||||
|
||||
return rooms
|
||||
}
|
||||
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
|
||||
err = rows.Scan(&roomID)
|
||||
if err != nil {
|
||||
store.log.Warnfln("Failed to scan room ID: %v", err)
|
||||
} else {
|
||||
rooms = append(rooms, roomID)
|
||||
}
|
||||
}
|
||||
|
||||
return rooms
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user