xpra icon
Bug tracker and wiki

Ticket #2287: sql.patch

File sql.patch, 20.3 KB (added by Antoine Martin, 10 months ago)

generic sql auth

  • xpra/server/auth/mysql_auth.py

     
     1#!/usr/bin/env python
     2# This file is part of Xpra.
     3# Copyright (C) 2019 Antoine Martin <antoine@xpra.org>
     4# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
     5# later version. See the file COPYING for details.
     6
     7import re
     8import sys
     9
     10from xpra.server.auth.sys_auth_base import init, log
     11from xpra.server.auth.sqlauthbase import SQLAuthenticator, DatabaseUtilBase, run_dbutil
     12assert init and log #tests will disable logging from here
     13
     14
     15def url_path_to_dict(path):
     16    pattern = (r'^'
     17               r'((?P<schema>.+?)://)?'
     18               r'((?P<user>.+?)(:(?P<password>.*?))?@)?'
     19               r'(?P<host>.*?)'
     20               r'(:(?P<port>\d+?))?'
     21               r'(?P<path>/.*?)?'
     22               r'(?P<query>[?].*?)?'
     23               r'$'
     24               )
     25    regex = re.compile(pattern)
     26    m = regex.match(path)
     27    d = m.groupdict() if m is not None else None
     28    return d
     29
     30def db_from_uri(uri):
     31    d = url_path_to_dict(uri)
     32    log("settings for uri=%s : %s", uri, d)
     33    import mysql.connector as mysql  #@UnresolvedImport
     34    db = mysql.connect(
     35        host = d.get("host", "localhost"),
     36        #port = int(d.get("port", 3306)),
     37        user = d.get("user", ""),
     38        passwd = d.get("password", ""),
     39        database = d.get("path", "").lstrip("/") or "xpra",
     40    )
     41    return db
     42
     43
     44class Authenticator(SQLAuthenticator):
     45
     46    def __init__(self, username, uri, **kwargs):
     47        SQLAuthenticator.__init__(self, username, **kwargs)
     48        self.uri = uri
     49
     50    def db_cursor(self, *sqlargs):
     51        db = db_from_uri(self.uri)
     52        cursor = db.cursor()
     53        cursor.execute(*sqlargs)
     54        #keep reference to db so it doesn't get garbage collected just yet:
     55        cursor.db = db
     56        log("db_cursor(%s)=%s", sqlargs, cursor)
     57        return cursor
     58
     59    def __repr__(self):
     60        return "mysql"
     61
     62
     63class MySQLDatabaseUtil(DatabaseUtilBase):
     64
     65    def __init__(self, uri):
     66        DatabaseUtilBase.__init__(self, uri)
     67        import mysql.connector as mysql  #@UnresolvedImport
     68        assert mysql.paramstyle=="pyformat"
     69        self.param = "%s"
     70
     71    def exec_database_sql_script(self, cursor_cb, *sqlargs):
     72        db = db_from_uri(self.uri)
     73        cursor = db.cursor()
     74        log("%s.execute%s", cursor, sqlargs)
     75        cursor.execute(*sqlargs)
     76        if cursor_cb:
     77            cursor_cb(cursor)
     78        db.commit()
     79        return cursor
     80
     81    def get_authenticator_class(self):
     82        return Authenticator
     83
     84
     85def main():
     86    return run_dbutil(MySQLDatabaseUtil, "databaseURI", sys.argv)
     87
     88if __name__ == "__main__":
     89    sys.exit(main())
  • xpra/server/auth/sqlauthbase.py

    Property changes on: xpra/server/auth/mysql_auth.py
    ___________________________________________________________________
    Added: svn:executable
    ## -0,0 +1 ##
    +*
    \ No newline at end of property
     
     1#!/usr/bin/env python
     2# This file is part of Xpra.
     3# Copyright (C) 2017-2019 Antoine Martin <antoine@xpra.org>
     4# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
     5# later version. See the file COPYING for details.
     6
     7from xpra.util import csv, parse_simple_dict
     8from xpra.os_util import getuid, getgid
     9from xpra.server.auth.sys_auth_base import SysAuthenticator, init, log
     10assert init and log #tests will disable logging from here
     11
     12
     13class SQLAuthenticator(SysAuthenticator):
     14
     15    def __init__(self, username, **kwargs):
     16        self.password_query = kwargs.pop("password_query", "SELECT password FROM users WHERE username=(%s)")
     17        self.sessions_query = kwargs.pop("sessions_query",
     18                                         "SELECT uid, gid, displays, env_options, session_options "+
     19                                         "FROM users WHERE username=(%s) AND password=(%s)")
     20        SysAuthenticator.__init__(self, username, **kwargs)
     21        self.authenticate = self.authenticate_hmac
     22
     23    def db_cursor(self, *sqlargs):
     24        raise NotImplementedError()
     25
     26    def get_passwords(self):
     27        cursor = self.db_cursor(self.password_query, (self.username,))
     28        data = cursor.fetchall()
     29        if not data:
     30            log.info("username '%s' not found in sqlauth database", self.username)
     31            return None
     32        return tuple(str(x[0]) for x in data)
     33
     34    def get_sessions(self):
     35        cursor = self.db_cursor(self.sessions_query, (self.username, self.password_used or ""))
     36        data = cursor.fetchone()
     37        if not data:
     38            return None
     39        return self.parse_session_data(data)
     40
     41    def parse_session_data(self, data):
     42        try:
     43            uid = data[0]
     44            gid = data[1]
     45            displays = []
     46            env_options = {}
     47            session_options = {}
     48            if len(data)>2:
     49                displays = [x.strip() for x in str(data[2]).split(",")]
     50            if len(data)>3:
     51                env_options = parse_simple_dict(str(data[3]), ";")
     52            if len(data)>4:
     53                session_options = parse_simple_dict(str(data[4]), ";")
     54        except Exception as e:
     55            log("parse_session_data() error on row %s", data, exc_info=True)
     56            log.error("Error: sqlauth database row parsing problem:")
     57            log.error(" %s", e)
     58            return None
     59        return uid, gid, displays, env_options, session_options
     60
     61
     62class DatabaseUtilBase(object):
     63
     64    def __init__(self, uri):
     65        self.uri = uri
     66        self.param = "?"
     67
     68    def exec_database_sql_script(self, cursor_cb, *sqlargs):
     69        raise NotImplementedError()
     70
     71    def create(self):
     72        sql = ("CREATE TABLE users ("
     73               "username VARCHAR(255) NOT NULL, "
     74               "password VARCHAR(255), "
     75               "uid VARCHAR(63), "
     76               "gid VARCHAR(63), "
     77               "displays VARCHAR(8191), "
     78               "env_options VARCHAR(8191), "
     79               "session_options VARCHAR(8191))")
     80        return self.exec_database_sql_script(None, sql)
     81
     82    def add_user(self, username, password, uid=getuid(), gid=getgid(),
     83                 displays="", env_options="", session_options=""):
     84        sql = "INSERT INTO users(username, password, uid, gid, displays, env_options, session_options) "+\
     85              "VALUES(%s, %s, %s, %s, %s, %s, %s)" % ((self.param,)*7)
     86        return self.exec_database_sql_script(None, sql,
     87                                        (username, password, uid, gid, displays, env_options, session_options))
     88
     89    def remove_user(self, username, password=None):
     90        sql = "DELETE FROM users WHERE username=%s" % self.param
     91        sqlargs = (username, )
     92        if password:
     93            sql += " AND password=%s" % self.param
     94            sqlargs = (username, password)
     95        return self.exec_database_sql_script(None, sql, sqlargs)
     96
     97    def list_users(self):
     98        fields = ("username", "password", "uid", "gid", "displays", "env_options", "session_options")
     99        def fmt(values, sizes):
     100            s = ""
     101            for i, field in enumerate(values):
     102                if i==0:
     103                    s += "|"
     104                s += ("%s" % field).rjust(sizes[i])+"|"
     105            return s
     106        def cursor_callback(cursor):
     107            rows = cursor.fetchall()
     108            if not rows:
     109                print("no rows found")
     110                cursor.close()
     111                return
     112            print("%i rows found:" % len(rows))
     113            #calculate max size for each field:
     114            sizes = [len(x)+1 for x in fields]
     115            for row in rows:
     116                for i, value in enumerate(row):
     117                    sizes[i] = max(sizes[i], len(str(value))+1)
     118            total = sum(sizes)+len(fields)+1
     119            print("-"*total)
     120            print(fmt((field.replace("_", " ") for field in fields), sizes))
     121            print("-"*total)
     122            for row in rows:
     123                print(fmt(row, sizes))
     124            cursor.close()
     125        sql = "SELECT %s FROM users" % csv(fields)
     126        self.exec_database_sql_script(cursor_callback, sql)
     127        return 0
     128
     129    def authenticate(self, username, password):
     130        auth_class = self.get_authenticator_class()
     131        a = auth_class(username, self.uri)
     132        passwords = a.get_passwords()
     133        assert passwords
     134        log("authenticate: got %i passwords", len(passwords))
     135        assert password in passwords
     136        a.password_used = password
     137        sessions = a.get_sessions()
     138        assert sessions
     139        print("success, found sessions: %s" % (sessions, ))
     140        return 0
     141
     142    def get_authenticator_class(self):
     143        raise NotImplementedError()
     144
     145
     146def run_dbutil(DatabaseUtilClass=DatabaseUtilBase, conn_str="databaseURI", argv=()):
     147    def usage(msg="invalid number of arguments"):
     148        print(msg)
     149        print("usage:")
     150        print(" %s %s create" % (argv[0], conn_str))
     151        print(" %s %s list" % (argv[0], conn_str))
     152        print(" %s %s add username password [uid, gid, displays, env_options, session_options]" % (argv[0], conn_str))
     153        print(" %s %s remove username [password]" % (argv[0], conn_str))
     154        print(" %s %s authenticate username password" % (argv[0], conn_str))
     155        return 1
     156    from xpra.platform import program_context
     157    with program_context("SQL Auth", "SQL Auth"):
     158        l = len(argv)
     159        if l<3:
     160            return usage()
     161        uri = argv[1]
     162        dbutil = DatabaseUtilClass(uri)
     163        cmd = argv[2]
     164        if cmd=="create":
     165            if l!=3:
     166                return usage()
     167            return dbutil.create()
     168        if cmd=="add":
     169            if l<5 or l>10:
     170                return usage()
     171            return dbutil.add_user(*argv[3:])
     172        if cmd=="remove":
     173            if l not in (4, 5):
     174                return usage()
     175            return dbutil.remove_user(*argv[3:])
     176        if cmd=="list":
     177            if l!=3:
     178                return usage()
     179            return dbutil.list_users()
     180        if cmd=="authenticate":
     181            if l!=5:
     182                return usage()
     183            return dbutil.authenticate(*argv[3:])
     184        return usage("invalid command '%s'" % cmd)
     185    return 0
  • xpra/server/auth/sqlite_auth.py

    Property changes on: xpra/server/auth/sqlauthbase.py
    ___________________________________________________________________
    Added: svn:executable
    ## -0,0 +1 ##
    +*
    \ No newline at end of property
     
    44# Xpra is released under the terms of the GNU GPL v2, or, at your option, any
    55# later version. See the file COPYING for details.
    66
     7import os
    78import sys
    8 import os
    99
    10 from xpra.util import parse_simple_dict, csv, engs
    11 from xpra.os_util import getuid, getgid
    12 from xpra.server.auth.sys_auth_base import SysAuthenticator, init, log, parse_uid, parse_gid
     10from xpra.util import parse_simple_dict
     11from xpra.server.auth.sys_auth_base import init, log, parse_uid, parse_gid
     12from xpra.server.auth.sqlauthbase import SQLAuthenticator, DatabaseUtilBase, run_dbutil
    1313assert init and log #tests will disable logging from here
    1414
    1515
    16 class Authenticator(SysAuthenticator):
     16class Authenticator(SQLAuthenticator):
    1717
    18     def __init__(self, username, **kwargs):
    19         filename = kwargs.pop("filename", 'sqlite.sdb')
     18    def __init__(self, username, filename="sqlite.sdb", **kwargs):
     19        SQLAuthenticator.__init__(self, username)
    2020        if filename and not os.path.isabs(filename):
    2121            exec_cwd = kwargs.get("exec_cwd", os.getcwd())
    2222            filename = os.path.join(exec_cwd, filename)
     
    2525        self.sessions_query = kwargs.pop("sessions_query",
    2626                                         "SELECT uid, gid, displays, env_options, session_options "+
    2727                                         "FROM users WHERE username=(?) AND password=(?)")
    28         SysAuthenticator.__init__(self, username, **kwargs)
    2928        self.authenticate = self.authenticate_hmac
    3029
    3130    def __repr__(self):
    3231        return "sqlite"
    3332
    34     def get_passwords(self):
     33    def db_cursor(self, *sqlargs):
    3534        if not os.path.exists(self.filename):
    3635            log.error("Error: sqlauth cannot find the database file '%s'", self.filename)
    3736            return None
    38         log("sqlauth.get_password() found database file '%s'", self.filename)
    3937        import sqlite3
    40         try:
    41             conn = sqlite3.connect(self.filename)
    42             cursor = conn.cursor()
    43             cursor.execute(self.password_query, [self.username])
    44             data = cursor.fetchall()
    45         except sqlite3.DatabaseError as e:
    46             log("get_password()", exc_info=True)
    47             log.error("Error: sqlauth database access problem:")
    48             log.error(" %s", e)
    49             return None
    50         if not data:
    51             log.info("username '%s' not found in sqlauth database", self.username)
    52             return None
    53         return tuple(str(x[0]) for x in data)
     38        db = sqlite3.connect(self.filename)
     39        cursor = db.cursor()
     40        cursor.execute(*sqlargs)
     41        #keep reference to db so it doesn't get garbage collected just yet:
     42        cursor.db = db
     43        log("db_cursor(%s)=%s", sqlargs, cursor)
     44        return cursor
    5445
    55     def get_sessions(self):
    56         import sqlite3
     46    def parse_session_data(self, data):
    5747        try:
    58             conn = sqlite3.connect(self.filename)
    59             conn.row_factory = sqlite3.Row
    60             cursor = conn.cursor()
    61             cursor.execute(self.sessions_query, [self.username, self.password_used or ""])
    62             data = cursor.fetchone()
    63         except sqlite3.DatabaseError as e:
    64             log("get_sessions()", exc_info=True)
    65             log.error("Error: sqlauth database access problem:")
    66             log.error(" %s", e)
    67             return None
    68         try:
    6948            uid = parse_uid(data["uid"])
    7049            gid = parse_gid(data["gid"])
    7150            displays = []
     
    7251            env_options = {}
    7352            session_options = {}
    7453            if data["displays"]:
    75                 displays = [x.strip() for x in str(data[2]).split(",")]
     54                displays = [x.strip() for x in str(data["displays"]).split(",")]
    7655            if data["env_options"]:
    77                 env_options = parse_simple_dict(str(data[3]), ";")
     56                env_options = parse_simple_dict(str(data["env_options"]), ";")
    7857            if data["session_options"]:
    79                 session_options=parse_simple_dict(str(data[4]), ";")
     58                session_options=parse_simple_dict(str(data["session_options"]), ";")
    8059        except Exception as e:
    8160            log("get_sessions() error on row %s", data, exc_info=True)
    8261            log.error("Error: sqlauth database row parsing problem:")
     
    8564        return uid, gid, displays, env_options, session_options
    8665
    8766
    88 def exec_database_sql_script(cursor_cb, filename, *sqlargs):
    89     log("exec_database_sql_script%s", (cursor_cb, filename, sqlargs))
    90     import sqlite3
    91     try:
    92         conn = sqlite3.connect(filename)
    93         cursor = conn.cursor()
     67class SqliteDatabaseUtil(DatabaseUtilBase):
     68
     69    def __init__(self, uri):
     70        DatabaseUtilBase.__init__(self, uri)
     71        import sqlite3
     72        assert sqlite3.paramstyle=="qmark"
     73        self.param = "?"
     74
     75    def exec_database_sql_script(self, cursor_cb, *sqlargs):
     76        import sqlite3
     77        db = sqlite3.connect(self.uri)
     78        cursor = db.cursor()
     79        log("%s.execute%s", cursor, sqlargs)
    9480        cursor.execute(*sqlargs)
    9581        if cursor_cb:
    9682            cursor_cb(cursor)
    97         conn.commit()
    98         conn.close()
    99         return 0
    100     except sqlite3.DatabaseError as e:
    101         log.error("Error: database access problem:")
    102         log.error(" %s", e)
    103         return 1
     83        db.commit()
     84        return cursor
    10485
     86    def get_authenticator_class(self):
     87        return Authenticator
    10588
    106 def create(filename):
    107     if os.path.exists(filename):
    108         log.error("Error: database file '%s' already exists", filename)
    109         return 1
    110     sql = ("CREATE TABLE users ("
    111            "username VARCHAR NOT NULL, "
    112            "password VARCHAR, "
    113            "uid VARCHAR, "
    114            "gid VARCHAR, "
    115            "displays VARCHAR, "
    116            "env_options VARCHAR, "
    117            "session_options VARCHAR)")
    118     return exec_database_sql_script(None, filename, sql)
    11989
    120 def add_user(filename, username, password, uid=getuid(), gid=getgid(), displays="", env_options="", session_options=""):
    121     sql = "INSERT INTO users(username, password, uid, gid, displays, env_options, session_options) "+\
    122           "VALUES(?, ?, ?, ?, ?, ?, ?)"
    123     return exec_database_sql_script(None, filename, sql,
    124                                     (username, password, uid, gid, displays, env_options, session_options))
     90def main():
     91    return run_dbutil(SqliteDatabaseUtil, "filename", sys.argv)
    12592
    126 def remove_user(filename, username, password=None):
    127     sql = "DELETE FROM users WHERE username=?"
    128     sqlargs = (username, )
    129     if password:
    130         sql += " AND password=?"
    131         sqlargs = (username, password)
    132     return exec_database_sql_script(None, filename, sql, sqlargs)
    133 
    134 def list_users(filename):
    135     fields = ("username", "password", "uid", "gid", "displays", "env_options", "session_options")
    136     def fmt(values, sizes):
    137         s = ""
    138         for i, field in enumerate(values):
    139             if i==0:
    140                 s += "|"
    141             s += ("%s" % field).rjust(sizes[i])+"|"
    142         return s
    143     def cursor_callback(cursor):
    144         rows = cursor.fetchall()
    145         if not rows:
    146             print("no rows found")
    147             return
    148         print("%i rows found:" % len(rows))
    149         #calculate max size for each field:
    150         sizes = [len(x)+1 for x in fields]
    151         for row in rows:
    152             for i, value in enumerate(row):
    153                 sizes[i] = max(sizes[i], len(str(value))+1)
    154         total = sum(sizes)+len(fields)+1
    155         print("-"*total)
    156         print(fmt((field.replace("_", " ") for field in fields), sizes))
    157         print("-"*total)
    158         for row in rows:
    159             print(fmt(row, sizes))
    160     sql = "SELECT %s FROM users" % csv(fields)
    161     return exec_database_sql_script(cursor_callback, filename, sql)
    162 
    163 def authenticate(filename, username, password):
    164     a = Authenticator(username, filename=filename)
    165     passwords = a.get_passwords()
    166     assert passwords
    167     assert password in passwords
    168     sessions = a.get_sessions()
    169     assert sessions
    170     print("success, found %i session%s: %s" % (len(sessions), engs(sessions), sessions))
    171     return 0
    172 
    173 def main(argv):
    174     def usage(msg="invalid number of arguments"):
    175         print(msg)
    176         print("usage:")
    177         print(" %s databasefile create" % sys.argv[0])
    178         print(" %s databasefile list" % sys.argv[0])
    179         print(" %s databasefile add username password [uid, gid, displays, env_options, session_options]" % sys.argv[0])
    180         print(" %s databasefile remove username [password]" % sys.argv[0])
    181         print(" %s databasefile authenticate username password" % sys.argv[0])
    182         return 1
    183     from xpra.platform import program_context
    184     with program_context("SQL Auth", "SQL Auth"):
    185         l = len(argv)
    186         if l<3:
    187             return usage()
    188         filename = argv[1]
    189         cmd = argv[2]
    190         if cmd=="create":
    191             if l!=3:
    192                 return usage()
    193             return create(filename)
    194         if cmd=="add":
    195             if l<5 or l>10:
    196                 return usage()
    197             return add_user(filename, *argv[3:])
    198         if cmd=="remove":
    199             if l not in (4, 5):
    200                 return usage()
    201             return remove_user(filename, *argv[3:])
    202         if cmd=="list":
    203             if l!=3:
    204                 return usage()
    205             return list_users(filename)
    206         if cmd=="authenticate":
    207             if l!=5:
    208                 return usage()
    209             return authenticate(filename, *argv[3:])
    210         return usage("invalid command '%s'" % cmd)
    211     return 0
    212 
    213 
    21493if __name__ == "__main__":
    215     sys.exit(main(sys.argv))
     94    sys.exit(main())