# ORM implementation (ORM实现): a minimal metaclass-based ORM built on MySQLClient.

from mysql_client import MySQLClient


class Field:
    """Describe one table column: its name, SQL type, primary-key flag and default.

    Args:
        name: column name used when building SQL.
        column_type: SQL type string for the column.
        primary_key: truthy when this column is the table's primary key.
        default: value to use when a model instance has no value for the column.
    """

    def __init__(self, name, column_type, primary_key, default):
        # Identity first, then column metadata (same assignment order as before).
        self.name, self.column_type = name, column_type
        self.primary_key, self.default = primary_key, default


class Integerfield(Field):
    """Integer column; the SQL type defaults to ``int``."""

    def __init__(self, name, column_type='int', primary_key=False, default=None):
        super().__init__(
            name=name,
            column_type=column_type,
            primary_key=primary_key,
            default=default,
        )


class Stringfield(Field):
    """String column; the SQL type defaults to ``varchar(64)``."""

    def __init__(self, name, column_type='varchar(64)', primary_key=False, default=None):
        super(Stringfield, self).__init__(
            name=name,
            column_type=column_type,
            primary_key=primary_key,
            default=default,
        )


class OrmMetaClass(type):
    """Metaclass that turns ``Field`` class attributes into table metadata.

    For every class except the one named ``Models`` it:

    * collects the class attributes that are ``Field`` instances into
      ``mappings`` (attribute name -> Field) and removes them from the class
      namespace;
    * records the ``name`` of the single primary-key field in ``primary_key``;
    * sets ``table_name`` to the class's own ``table_name`` attribute, or to the
      class name when none is given.

    Raises:
        TypeError: if the class declares more than one primary-key field, or
            none at all.
    """

    def __new__(cls, class_name, class_bases, class_dict):
        # The base class carries no columns; create it unchanged.
        # NOTE(review): matches by name, so any other class called 'Models'
        # is skipped too.
        if class_name == 'Models':
            return type.__new__(cls, class_name, class_bases, class_dict)

        table_name = class_dict.get('table_name', class_name)

        mappings = {}
        primary_key = None

        for key, value in class_dict.items():
            if not isinstance(value, Field):
                continue
            mappings[key] = value
            if value.primary_key:
                if primary_key is not None:
                    # "Only one primary key is allowed!"
                    raise TypeError('只能有一个主键!')
                primary_key = value.name

        # Remove the Field objects from the class namespace, so that
        # getattr(instance, name, default) on an unset column falls back to the
        # supplied default instead of returning the Field object itself.
        for key in mappings:
            class_dict.pop(key)

        if primary_key is None:
            # "A primary key is required!"
            raise TypeError('必须有一个主键!')

        # Previously stored under the misspelled key 'tablel_name', which left
        # ``cls.table_name`` undefined for models that did not set it themselves.
        class_dict['table_name'] = table_name
        class_dict['primary_key'] = primary_key
        class_dict['mappings'] = mappings

        return type.__new__(cls, class_name, class_bases, class_dict)


class Models(metaclass=OrmMetaClass):
    """Base class for ORM models.

    The methods below read the class attributes ``table_name`` (SQL table
    name) and ``mappings`` (attribute name -> Field) from the concrete model
    class. Each call opens its own ``MySQLClient`` and always closes it.
    """

    def __init__(self, **kwargs):
        # Store each supplied column value as an instance attribute.
        for name, value in kwargs.items():
            setattr(self, name, value)

    @classmethod
    def orm_select(cls, **kwargs):
        """Select rows from this model's table.

        With no keyword arguments, select every row. Otherwise filter on the
        FIRST keyword argument only (``column=value``); any further keyword
        arguments are ignored.

        Returns:
            Whatever ``MySQLClient.my_select`` returns.
        """
        mysql = MySQLClient()
        try:
            if not kwargs:
                sql = 'select * from %s' % cls.table_name
                return mysql.my_select(sql)

            key = next(iter(kwargs))
            value = kwargs[key]

            # '%%s' survives this formatting as the driver placeholder '%s', so
            # the value is passed separately. (The old code built the query with
            # '?' and discarded the result of .replace(), so '?' reached MySQL.)
            # NOTE(review): the column name ``key`` comes straight from the
            # caller's keyword and is interpolated unvalidated.
            sql = 'select * from %s where %s=%%s' % (cls.table_name, key)
            return mysql.my_select(sql, value)
        finally:
            # Previously this close() sat after ``return`` and never ran.
            mysql.close()

    def orm_insert(self):
        """Insert this instance as a new row.

        The primary-key field is skipped (presumably the database assigns it,
        e.g. auto_increment — TODO confirm). A column with no instance
        attribute uses its Field's ``default``.
        """
        mysql = MySQLClient()
        try:
            keys = []
            values = []
            for k, v in self.mappings.items():
                if v.primary_key:
                    continue
                keys.append(k)
                values.append(getattr(self, v.name, v.default))

            # One driver placeholder per inserted column.
            sql = 'insert into %s(%s) values (%s)' % (
                self.table_name,
                ','.join(keys),
                ','.join(['%s'] * len(keys)),
            )
            mysql.my_execute(sql, values)
        finally:
            mysql.close()

    def orm_update(self):
        """Write this instance's non-key column values back to its row.

        The row is located by the primary-key field's value. Every mapped
        field must be set on the instance; otherwise AttributeError is raised.
        """
        mysql = MySQLClient()
        try:
            assignments = []
            values = []
            pk_column = None
            pk_value = None

            # Was ``self.mappings.item()``, which raised AttributeError.
            for field in self.mappings.values():
                if field.primary_key:
                    pk_column = field.name
                    pk_value = getattr(self, field.name)
                else:
                    # Plain concatenation: '%s' stays a literal driver placeholder.
                    assignments.append(field.name + '=%s')
                    values.append(getattr(self, field.name))

            sql = 'update %s set %s where %s=%%s' % (
                self.table_name,
                ','.join(assignments),
                pk_column,
            )
            # Bind the key as a parameter. The old code formatted its raw value
            # into the SQL, which broke for string keys and allowed injection.
            values.append(pk_value)
            mysql.my_execute(sql, values)
        finally:
            mysql.close()


class User(Models):
    # Column declarations; OrmMetaClass collects them and requires exactly one
    # primary key. Declaration order determines column order in generated SQL.
    user_id = Integerfield('user_id', primary_key=True)
    user_name = Stringfield('user_name')
    pwd = Stringfield('pwd')
    
# 01-09 01:59