介绍
基于内存实现一个简易的、类似 MySQL 的数据库,完整代码附在文末。该实现仍有优化空间,例如后续可以将数据持久化到磁盘。
创建一个Database类
Database类内部维护一个Map,保存所有的Table对象。除了表的读写操作之外,Database类还支持事务。在这个实现中,每个事务作用于所有表。对于每个表,当前事务用一个transactionRows Map表示,其中保存本事务中已插入但尚未提交的行。如果在事务中插入新行,这些行会先被添加到transactionRows,而不是直接写入表中;只有在事务提交时,它们才会真正插入到表中。
创建一个Table类
Table类包含表名、一组列名及列类型,以及一个保存所有行的List。Table类提供insertRow方法用于向表中插入行,并提供selectRows方法用于查询;查询时既可以选择所有列,也可以只选择特定列。
支持事务
为了支持事务,Table类维护两个Map:transactionRows和transactionLog,二者都以线程ID为键。transactionRows保存当前事务中已插入但尚未提交的行,transactionLog记录与当前事务相关的所有操作。在事务中插入新行时,新行会同时被添加到transactionRows和transactionLog中。事务提交时,transactionRows中的行被添加到Table的行中,transactionLog中的操作随之生效;如果事务被回滚,则丢弃transactionRows中已插入的行,并清空transactionLog。
防止并发问题
另外,为了防止多个线程同时修改同一个表,Table类中包含一个读写锁(ReentrantReadWriteLock),确保同一时刻只有一个线程可以写入数据,而多个线程可以并发读取。insertRow方法使用写锁,selectRows方法使用读锁。
实现功能
代码还支持解析简化的SQL语句(各部分以空格分隔),并执行相应的操作。支持的语句如下:
创建表 create table [table name] [column names]
插入行 insert into [table name] [values]
选择所有列 select * from [table name]
选择特定列 select [column names] from [table name]
开始事务 begin
提交事务 commit
回滚事务 rollback
import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
/**
 * A minimal in-memory, MySQL-like database.
 *
 * <p>Tables are kept in a map guarded by {@link #tablesLock}; every {@link Table} guards its own
 * rows with a read/write lock. Transactions are tracked per thread (keyed by thread id) and span
 * every table that exists at the moment {@code begin} is executed.
 *
 * <p>{@link #execute(String)} understands a tiny whitespace-separated grammar:
 * <ul>
 *   <li>{@code create table <name> <col> [<col> ...]}</li>
 *   <li>{@code insert into <name> <value> [<value> ...]}</li>
 *   <li>{@code select * from <name>} or {@code select <col> [<col> ...] from <name>}</li>
 *   <li>{@code begin}, {@code commit}, {@code rollback}</li>
 * </ul>
 */
public class Database {
    /** All tables by name; guarded by {@link #tablesLock}. */
    private final Map<String, Table> tables;
    private final ReadWriteLock tablesLock;

    /** Creates an empty database with no tables. */
    public Database() {
        tables = new HashMap<>();
        tablesLock = new ReentrantReadWriteLock();
    }

    /**
     * Creates a table, replacing any existing table with the same name.
     *
     * @param tableName   the table name
     * @param columnNames the column names, in order
     * @param columnTypes the declared column types, parallel to {@code columnNames}
     * @throws IllegalArgumentException if any argument is null or the two arrays differ in length
     */
    public void createTable(String tableName, String[] columnNames, Class<?>[] columnTypes) {
        if (tableName == null || columnNames == null || columnTypes == null
                || columnNames.length != columnTypes.length) {
            throw new IllegalArgumentException("Invalid table definition: " + tableName);
        }
        // Copy the arrays so later mutation by the caller cannot alter the table's schema.
        Table table = new Table(tableName, columnNames.clone(), columnTypes.clone());
        tablesLock.writeLock().lock();
        try {
            tables.put(tableName, table);
        } finally {
            tablesLock.writeLock().unlock();
        }
    }

    /**
     * Parses and executes one statement. Select results are printed to {@code System.out}, one
     * row per line with every value followed by a tab.
     *
     * @param statement the statement text
     * @throws RuntimeException if the statement is null, empty or malformed, names an unknown
     *                          table or column, or has the wrong number of values
     */
    public void execute(String statement) {
        if (statement == null || statement.trim().isEmpty()) {
            throw new RuntimeException("Invalid statement");
        }
        String[] tokens = statement.trim().split("\\s+");
        String command = tokens[0];
        if (command.equalsIgnoreCase("create")) {
            executeCreate(tokens);
        } else if (command.equalsIgnoreCase("insert")) {
            executeInsert(tokens);
        } else if (command.equalsIgnoreCase("select")) {
            executeSelect(tokens);
        } else if (command.equalsIgnoreCase("begin")) {
            beginTransaction();
        } else if (command.equalsIgnoreCase("commit")) {
            commitTransaction();
        } else if (command.equalsIgnoreCase("rollback")) {
            rollbackTransaction();
        } else {
            throw new RuntimeException("Invalid statement");
        }
    }

    /** Handles {@code create table <name> <col> [<col> ...]}; every column is typed as String. */
    private void executeCreate(String[] tokens) {
        if (tokens.length < 4 || !tokens[1].equalsIgnoreCase("table")) {
            throw new RuntimeException("Invalid statement");
        }
        String[] columnNames = Arrays.copyOfRange(tokens, 3, tokens.length);
        Class<?>[] columnTypes = new Class<?>[columnNames.length];
        Arrays.fill(columnTypes, String.class);
        createTable(tokens[2], columnNames, columnTypes);
    }

    /** Handles {@code insert into <name> <value> [<value> ...]}; values are stored as Strings. */
    private void executeInsert(String[] tokens) {
        if (tokens.length < 4 || !tokens[1].equalsIgnoreCase("into")) {
            throw new RuntimeException("Invalid statement");
        }
        Object[] row = Arrays.copyOfRange(tokens, 3, tokens.length, Object[].class);
        getTable(tokens[2]).insertRow(row);
    }

    /** Handles {@code select * from <name>} and {@code select <col> [<col> ...] from <name>}. */
    private void executeSelect(String[] tokens) {
        // The table name must be the last token, so "from" must be second-to-last. Anything
        // else (missing "from", trailing clauses such as "where") is rejected rather than ignored.
        int fromIndex = tokens.length - 2;
        if (tokens.length < 4 || !tokens[fromIndex].equalsIgnoreCase("from")) {
            throw new RuntimeException("Invalid statement");
        }
        String[] columnNames;
        if (fromIndex == 2 && tokens[1].equals("*")) {
            columnNames = null; // null means "all columns"
        } else {
            columnNames = Arrays.copyOfRange(tokens, 1, fromIndex);
        }
        List<Object[]> rows = getTable(tokens[tokens.length - 1]).selectRows(columnNames);
        for (Object[] row : rows) {
            StringBuilder line = new StringBuilder();
            for (Object value : row) {
                line.append(value).append('\t');
            }
            System.out.println(line);
        }
    }

    /**
     * Looks up a table by name.
     *
     * @throws IllegalArgumentException if no such table exists
     */
    private Table getTable(String tableName) {
        tablesLock.readLock().lock();
        try {
            Table table = tables.get(tableName);
            if (table == null) {
                throw new IllegalArgumentException("Table not found: " + tableName);
            }
            return table;
        } finally {
            tablesLock.readLock().unlock();
        }
    }

    /** Returns a snapshot of all tables so callers can iterate without holding the lock. */
    private List<Table> getTables() {
        tablesLock.readLock().lock();
        try {
            return new ArrayList<>(tables.values());
        } finally {
            tablesLock.readLock().unlock();
        }
    }

    /** Starts a transaction for the current thread on every existing table. */
    private void beginTransaction() {
        for (Table table : getTables()) {
            table.beginTransaction();
        }
    }

    /** Commits the current thread's transaction on every table. */
    private void commitTransaction() {
        for (Table table : getTables()) {
            table.commitTransaction();
        }
    }

    /** Discards the current thread's transaction on every table. */
    private void rollbackTransaction() {
        for (Table table : getTables()) {
            table.rollbackTransaction();
        }
    }

    /**
     * One table: committed rows plus per-thread uncommitted transaction state, all guarded by
     * {@link #lock}.
     */
    private static class Table {
        private final String tableName;
        private final String[] columnNames;
        /** Declared column types; stored but not currently enforced on insert. */
        private final Class<?>[] columnTypes;
        /** Committed rows, in insertion order. */
        private final List<Object[]> rows;
        /** Per thread id: rows inserted in an open transaction that are not yet committed. */
        private final Map<Long, List<Object[]>> transactionRows;
        /** Per thread id: operations of an open transaction, replayed in order on commit. */
        private final Map<Long, List<Operation>> transactionLog;
        private final ReadWriteLock lock;

        public Table(String tableName, String[] columnNames, Class<?>[] columnTypes) {
            this.tableName = tableName;
            this.columnNames = columnNames;
            this.columnTypes = columnTypes;
            this.rows = new ArrayList<>();
            this.transactionRows = new HashMap<>();
            this.transactionLog = new HashMap<>();
            this.lock = new ReentrantReadWriteLock();
        }

        /**
         * Inserts a row. Inside a transaction of the calling thread the row stays pending until
         * commit; otherwise it is committed immediately.
         *
         * @throws IllegalArgumentException if the number of values differs from the column count
         */
        public void insertRow(Object[] row) {
            if (row == null || row.length != columnNames.length) {
                throw new IllegalArgumentException("Column count doesn't match value count for table "
                        + tableName + ": expected " + columnNames.length + ", got "
                        + (row == null ? 0 : row.length));
            }
            // Store a private copy so the caller cannot mutate table contents afterwards.
            Object[] stored = row.clone();
            long threadId = currentThreadId();
            lock.writeLock().lock();
            try {
                List<Operation> log = transactionLog.get(threadId);
                if (log == null) {
                    rows.add(stored);
                } else {
                    transactionRows.get(threadId).add(stored);
                    log.add(new InsertOperation(stored));
                }
            } finally {
                lock.writeLock().unlock();
            }
        }

        /**
         * Returns committed rows followed by the calling thread's own pending rows; pending rows
         * of other threads are not visible.
         *
         * @param requestedColumns column names to project, or {@code null} for all columns
         * @return fresh arrays, safe for the caller to modify
         * @throws IllegalArgumentException if a requested column does not exist
         */
        public List<Object[]> selectRows(String[] requestedColumns) {
            // Resolve names once up front so unknown columns are reported even on empty tables.
            int[] indices = resolveColumns(requestedColumns);
            long threadId = currentThreadId();
            lock.readLock().lock();
            try {
                List<Object[]> selected = new ArrayList<>();
                for (Object[] row : rows) {
                    selected.add(project(row, indices));
                }
                List<Object[]> pending = transactionRows.get(threadId);
                if (pending != null) {
                    for (Object[] row : pending) {
                        selected.add(project(row, indices));
                    }
                }
                return selected;
            } finally {
                lock.readLock().unlock();
            }
        }

        /** Maps column names to indices; {@code null} input (all columns) yields {@code null}. */
        private int[] resolveColumns(String[] requestedColumns) {
            if (requestedColumns == null) {
                return null;
            }
            int[] indices = new int[requestedColumns.length];
            for (int i = 0; i < requestedColumns.length; i++) {
                indices[i] = getColumnIndex(requestedColumns[i]);
            }
            return indices;
        }

        /** Copies the selected columns of a row; {@code null} indices copies the whole row. */
        private static Object[] project(Object[] row, int[] indices) {
            if (indices == null) {
                return row.clone();
            }
            Object[] projected = new Object[indices.length];
            for (int i = 0; i < indices.length; i++) {
                projected[i] = row[indices[i]];
            }
            return projected;
        }

        /**
         * Opens a transaction for the calling thread. As in MySQL, beginning while a transaction
         * is already open implicitly commits the open one first.
         */
        public void beginTransaction() {
            long threadId = currentThreadId();
            lock.writeLock().lock();
            try {
                applyPending(threadId);
                transactionRows.put(threadId, new ArrayList<>());
                transactionLog.put(threadId, new ArrayList<>());
            } finally {
                lock.writeLock().unlock();
            }
        }

        /** Applies the calling thread's pending operations; a no-op when no transaction is open. */
        public void commitTransaction() {
            long threadId = currentThreadId();
            lock.writeLock().lock();
            try {
                applyPending(threadId);
            } finally {
                lock.writeLock().unlock();
            }
        }

        /** Discards the calling thread's pending operations; a no-op when none are open. */
        public void rollbackTransaction() {
            long threadId = currentThreadId();
            lock.writeLock().lock();
            try {
                transactionRows.remove(threadId);
                transactionLog.remove(threadId);
            } finally {
                lock.writeLock().unlock();
            }
        }

        /** Replays and clears a thread's transaction log. Caller must hold the write lock. */
        private void applyPending(long threadId) {
            transactionRows.remove(threadId);
            List<Operation> log = transactionLog.remove(threadId);
            if (log != null) {
                for (Operation operation : log) {
                    operation.applyTo(rows);
                }
            }
        }

        /**
         * Case-insensitive column lookup.
         *
         * @throws IllegalArgumentException if the column does not exist
         */
        private int getColumnIndex(String columnName) {
            for (int i = 0; i < columnNames.length; i++) {
                if (columnNames[i].equalsIgnoreCase(columnName)) {
                    return i;
                }
            }
            throw new IllegalArgumentException("Column not found: " + columnName);
        }

        private static long currentThreadId() {
            return Thread.currentThread().getId();
        }

        /** A deferred change recorded in a transaction log. */
        private abstract static class Operation {
            /** Applies this change to the committed rows. */
            abstract void applyTo(List<Object[]> target);
        }

        /** A pending row insertion. */
        private static class InsertOperation extends Operation {
            private final Object[] row;

            public InsertOperation(Object[] row) {
                this.row = row;
            }

            public Object[] getRow() {
                return row;
            }

            @Override
            void applyTo(List<Object[]> target) {
                target.add(row);
            }
        }
    }

    /** Demo using the supported whitespace-separated grammar. */
    public static void main(String[] args) {
        Database db = new Database();
        db.execute("create table test id name");
        db.execute("insert into test 1 小明");
        db.execute("insert into test 2 小红");
        db.execute("insert into test 3 小李");
        db.execute("begin");
        db.execute("insert into test 4 小黑");
        db.execute("commit");
        db.execute("begin");
        db.execute("insert into test 5 小白");
        db.execute("rollback");
        db.execute("select * from test");
    }
}