use crate::error::{Result, SqlError}; use crate::types::{ColumnDef, DataType, Value}; use sqlparser::ast::{ ColumnDef as AstColumnDef, DataType as AstDataType, Expr, Statement, }; use sqlparser::dialect::GenericDialect; use sqlparser::parser::Parser; /// Parsed SQL statement #[derive(Debug, Clone)] pub enum SqlStatement { CreateTable { table_name: String, columns: Vec, primary_key: Vec, }, DropTable { table_name: String, }, Insert { table_name: String, columns: Vec, values: Vec, }, Select { table_name: String, columns: Vec, // Empty means SELECT * where_clause: Option, }, Update { table_name: String, assignments: Vec<(String, Value)>, where_clause: Option, }, Delete { table_name: String, where_clause: Option, }, } /// WHERE clause representation #[derive(Debug, Clone)] pub enum WhereClause { Comparison { column: String, op: ComparisonOp, value: Value, }, And(Box, Box), Or(Box, Box), } #[derive(Debug, Clone)] pub enum ComparisonOp { Eq, Ne, Lt, Le, Gt, Ge, } /// Parse SQL string into SqlStatement pub fn parse_sql(sql: &str) -> Result { let dialect = GenericDialect {}; let statements = Parser::parse_sql(&dialect, sql) .map_err(|e| SqlError::ParseError(format!("Parse error: {}", e)))?; if statements.is_empty() { return Err(SqlError::ParseError("No statement found".to_string())); } if statements.len() > 1 { return Err(SqlError::ParseError( "Multiple statements not supported".to_string(), )); } let statement = &statements[0]; parse_statement(statement) } fn parse_statement(stmt: &Statement) -> Result { match stmt { Statement::CreateTable { .. } => parse_create_table(stmt), Statement::Drop { names, .. } => { if names.len() != 1 { return Err(SqlError::ParseError("Expected single table name".to_string())); } Ok(SqlStatement::DropTable { table_name: names[0].to_string(), }) } Statement::Insert { .. } => parse_insert(stmt), Statement::Query(query) => parse_select(query), Statement::Update { .. } => { Err(SqlError::ParseError("UPDATE not yet implemented".to_string())) } Statement::Delete { .. } => { Err(SqlError::ParseError("DELETE not yet implemented".to_string())) } _ => Err(SqlError::ParseError(format!( "Unsupported statement: {:?}", stmt ))), } } fn parse_create_table(stmt: &Statement) -> Result { let Statement::CreateTable { name, columns: col_defs, constraints, .. } = stmt else { return Err(SqlError::ParseError("Expected CREATE TABLE statement".to_string())); }; let table_name = name.to_string(); let mut columns = Vec::new(); let mut primary_key = Vec::new(); for column in col_defs { let col_def = parse_column_def(column)?; columns.push(col_def); } // Extract primary key from constraints for constraint in constraints { if let sqlparser::ast::TableConstraint::Unique { columns: pk_cols, is_primary: true, .. } = constraint { for pk_col in pk_cols { primary_key.push(pk_col.value.to_string()); } } } // If no explicit PRIMARY KEY constraint, check for PRIMARY KEY in column definitions if primary_key.is_empty() { for column in col_defs { for option in &column.options { if matches!(option.option, sqlparser::ast::ColumnOption::Unique { is_primary: true }) { primary_key.push(column.name.value.to_string()); break; } } } } if primary_key.is_empty() { return Err(SqlError::ParseError( "PRIMARY KEY is required".to_string(), )); } Ok(SqlStatement::CreateTable { table_name, columns, primary_key, }) } fn parse_column_def(col: &AstColumnDef) -> Result { let name = col.name.value.to_string(); let data_type = parse_data_type(&col.data_type)?; let mut nullable = true; let mut default_value = None; for option in &col.options { match &option.option { sqlparser::ast::ColumnOption::NotNull => nullable = false, sqlparser::ast::ColumnOption::Null => nullable = true, sqlparser::ast::ColumnOption::Default(expr) => { default_value = Some(parse_expr_as_value(expr)?); } _ => {} } } Ok(ColumnDef { name, data_type, nullable, default_value, }) } fn parse_data_type(dt: &AstDataType) -> Result { match dt { AstDataType::Int(_) | AstDataType::Integer(_) | AstDataType::SmallInt(_) => { Ok(DataType::Integer) } AstDataType::BigInt(_) => Ok(DataType::BigInt), AstDataType::Text | AstDataType::Varchar(_) | AstDataType::Char(_) => Ok(DataType::Text), AstDataType::Boolean => Ok(DataType::Boolean), AstDataType::Timestamp(_, _) => Ok(DataType::Timestamp), _ => Err(SqlError::ParseError(format!( "Unsupported data type: {:?}", dt ))), } } fn parse_insert(stmt: &Statement) -> Result { let Statement::Insert { table_name, columns: col_idents, source, .. } = stmt else { return Err(SqlError::ParseError("Expected INSERT statement".to_string())); }; let table_name = table_name.to_string(); let columns: Vec = col_idents.iter().map(|c| c.value.to_string()).collect(); // Extract values from the first VALUES row if let sqlparser::ast::SetExpr::Values(values) = source.body.as_ref() { if values.rows.is_empty() { return Err(SqlError::ParseError("No values provided".to_string())); } let first_row = &values.rows[0]; let mut parsed_values = Vec::new(); for expr in first_row { parsed_values.push(parse_expr_as_value(expr)?); } Ok(SqlStatement::Insert { table_name, columns, values: parsed_values, }) } else { Err(SqlError::ParseError("Expected VALUES clause".to_string())) } } fn parse_select(query: &sqlparser::ast::Query) -> Result { // For simplicity, only handle basic SELECT FROM WHERE if let sqlparser::ast::SetExpr::Select(select) = query.body.as_ref() { // Extract table name if select.from.is_empty() { return Err(SqlError::ParseError("No FROM clause".to_string())); } let table_name = match &select.from[0].relation { sqlparser::ast::TableFactor::Table { name, .. } => name.to_string(), _ => { return Err(SqlError::ParseError( "Complex FROM clauses not supported".to_string(), )) } }; // Extract columns let columns: Vec = select .projection .iter() .filter_map(|item| match item { sqlparser::ast::SelectItem::UnnamedExpr(Expr::Identifier(ident)) => { Some(ident.value.to_string()) } sqlparser::ast::SelectItem::Wildcard(_) => None, // SELECT * returns empty vec _ => None, }) .collect(); // Parse WHERE clause if present let where_clause = if let Some(expr) = &select.selection { Some(parse_where_expr(expr)?) } else { None }; Ok(SqlStatement::Select { table_name, columns, where_clause, }) } else { Err(SqlError::ParseError( "Only SELECT queries supported".to_string(), )) } } fn parse_where_expr(expr: &Expr) -> Result { match expr { Expr::BinaryOp { left, op, right } => { use sqlparser::ast::BinaryOperator; match op { BinaryOperator::Eq | BinaryOperator::NotEq | BinaryOperator::Lt | BinaryOperator::LtEq | BinaryOperator::Gt | BinaryOperator::GtEq => { let column = if let Expr::Identifier(ident) = left.as_ref() { ident.value.to_string() } else { return Err(SqlError::ParseError( "Left side of comparison must be column name".to_string(), )); }; let value = parse_expr_as_value(right)?; let op = match op { BinaryOperator::Eq => ComparisonOp::Eq, BinaryOperator::NotEq => ComparisonOp::Ne, BinaryOperator::Lt => ComparisonOp::Lt, BinaryOperator::LtEq => ComparisonOp::Le, BinaryOperator::Gt => ComparisonOp::Gt, BinaryOperator::GtEq => ComparisonOp::Ge, _ => unreachable!(), }; Ok(WhereClause::Comparison { column, op, value }) } BinaryOperator::And => { let left_clause = parse_where_expr(left)?; let right_clause = parse_where_expr(right)?; Ok(WhereClause::And( Box::new(left_clause), Box::new(right_clause), )) } BinaryOperator::Or => { let left_clause = parse_where_expr(left)?; let right_clause = parse_where_expr(right)?; Ok(WhereClause::Or( Box::new(left_clause), Box::new(right_clause), )) } _ => Err(SqlError::ParseError(format!( "Unsupported operator: {:?}", op ))), } } _ => Err(SqlError::ParseError(format!( "Unsupported WHERE expression: {:?}", expr ))), } } fn parse_expr_as_value(expr: &Expr) -> Result { match expr { Expr::Value(sqlparser::ast::Value::Number(n, _)) => { if let Ok(i) = n.parse::() { Ok(Value::Integer(i)) } else { Err(SqlError::ParseError(format!("Invalid number: {}", n))) } } Expr::Value(sqlparser::ast::Value::SingleQuotedString(s)) => Ok(Value::Text(s.clone())), Expr::Value(sqlparser::ast::Value::Boolean(b)) => Ok(Value::Boolean(*b)), Expr::Value(sqlparser::ast::Value::Null) => Ok(Value::Null), _ => Err(SqlError::ParseError(format!( "Unsupported value expression: {:?}", expr ))), } }