Source code for tlsql.tlsql.ast_nodes
"""Abstract Syntax Tree node definitions.
Supports 3 statement types:
1. TRAIN WITH.
2. PREDICT VALUE.
3. VALIDATE WITH.
"""
from dataclasses import dataclass, field
from typing import Optional, List, Any
[docs]
@dataclass
class ASTNode:
    """Common root of every AST node.

    Declares no fields of its own; the other node classes derive from it so
    that they share a single type for identification and a unified interface.
    """
[docs]
@dataclass
class ColumnReference(ASTNode):
    """Reference to a column, written as 'table.column' or just 'column'.

    Attributes:
        table: Table name; None (the default) when the column is unqualified.
        column: Column name; defaults to the empty string.
    """

    table: Optional[str] = None
    column: str = ""

    def __str__(self) -> str:
        """Render as 'table.column' when a table is set, else the column alone."""
        # A falsy table (None or "") yields the bare column value.
        return f"{self.table}.{self.column}" if self.table else self.column
[docs]
@dataclass
class Expr(ASTNode):
    """Root type for expression nodes.

    Declares no fields; the concrete expression classes derive from it.
    """
[docs]
@dataclass
class LiteralExpr(Expr):
    """A literal value appearing in an expression.

    Attributes:
        value: The literal's value.
        value_type: Type tag of the value, 'number' or 'string'.
    """

    value: Any
    value_type: str
[docs]
@dataclass
class ColumnExpr(Expr):
    """Expression node that wraps a column reference.

    Attributes:
        column: The ColumnReference being used as an expression.
    """

    column: ColumnReference
[docs]
@dataclass
class BinaryExpr(Expr):
    """Expression with two operands joined by an operator.

    Supported operators:
        - Comparison: >, <, >=, <=, ==, !=, =.
        - Logical: AND, OR.

    Attributes:
        left: Expression on the left-hand side.
        operator: Operator text.
        right: Expression on the right-hand side.
    """

    left: Expr
    operator: str
    right: Expr
[docs]
@dataclass
class UnaryExpr(Expr):
    """Expression with a single operand.

    Supported operator:
        NOT: logical negation.

    Attributes:
        operator: Operator text.
        operand: Expression the operator applies to.
    """

    operator: str
    operand: Expr
[docs]
@dataclass
class BetweenExpr(Expr):
    """BETWEEN range test.

    Syntax: column BETWEEN value1 AND value2.

    Attributes:
        column: Expression for the column being tested.
        lower: Expression for the lower bound (value1).
        upper: Expression for the upper bound (value2).
    """

    column: Expr
    lower: Expr
    upper: Expr
[docs]
@dataclass
class InExpr(Expr):
    """IN membership test.

    Syntax: column IN (value1, value2, ...).

    Attributes:
        column: Expression for the column being tested.
        values: Expressions listed inside the parentheses.
    """

    column: Expr
    values: List[Expr]
[docs]
@dataclass
class WhereClause(ASTNode):
    """WHERE clause of a statement.

    Attributes:
        condition: Root of the condition expression tree.
    """

    condition: Expr
[docs]
@dataclass
class ColumnSelector(ASTNode):
    """One column selector inside a WITH clause.

    Attributes:
        table: Table name.
        column: Column name; '*' selects all columns of the table.
    """

    table: str
    column: str

    def __str__(self) -> str:
        """Render the selector as 'table.column'."""
        qualifier, name = self.table, self.column
        return f"{qualifier}.{name}"

    @property
    def is_wildcard(self) -> bool:
        """Whether this selector is the wildcard form (table.*)."""
        return self.column == "*"
[docs]
@dataclass
class WithClause(ASTNode):
    """WITH clause of a TRAIN/VALIDATE statement.

    Attributes:
        selectors: Column selectors in the clause; a fresh empty list by
            default.
    """

    selectors: List[ColumnSelector] = field(default_factory=list)
[docs]
@dataclass
class TablesClause(ASTNode):
    """FROM clause naming one or more tables.

    Syntax: FROM table1, table2, ...

    Attributes:
        tables: Table names; a fresh empty list by default.
    """

    tables: List[str] = field(default_factory=list)
[docs]
@dataclass
class TrainStatement(ASTNode):
    """TRAIN statement.

    Complete syntax:
        TRAIN WITH (column_selectors)
        FROM table1, table2, ...
        [WHERE conditions]

    Attributes:
        with_clause: The WITH clause.
        tables: The tables (FROM) clause.
        where: The WHERE clause; None (the default) when absent.
    """

    with_clause: WithClause
    tables: TablesClause
    where: Optional[WhereClause] = None

    def __repr__(self) -> str:
        """Return a multi-line summary: selector count, tables, WHERE marker."""
        head = [
            "TrainStatement(",
            f" with={len(self.with_clause.selectors)} selectors",
            f" tables={', '.join(self.tables.tables)}",
        ]
        # The WHERE line appears only when a (truthy) where clause is set.
        tail = [" where=<expression>"] if self.where else []
        return "\n".join([*head, *tail, ")"])
[docs]
@dataclass
class ValidateStatement(ASTNode):
    """VALIDATE statement.

    Complete syntax:
        VALIDATE WITH (column_selectors)
        FROM table1, table2, ...
        [WHERE conditions]

    Attributes:
        with_clause: The WITH clause.
        tables: The tables (FROM) clause.
        where: The WHERE clause; None (the default) when absent.
    """

    with_clause: WithClause
    tables: TablesClause
    where: Optional[WhereClause] = None

    def __repr__(self) -> str:
        """Return a multi-line summary: selector count, tables, WHERE marker."""
        selector_count = len(self.with_clause.selectors)
        table_list = ", ".join(self.tables.tables)
        lines = [
            "ValidateStatement(",
            f" with={selector_count} selectors",
            f" tables={table_list}",
        ]
        # The WHERE line appears only when a (truthy) where clause is set.
        if self.where:
            lines.append(" where=<expression>")
        lines.append(")")
        return "\n".join(lines)
[docs]
@dataclass
class PredictType(ASTNode):
    """Kind of prediction requested: CLF or REG.

    Attributes:
        type_name: Prediction type name; compared case-insensitively by the
            properties below.
    """

    type_name: str

    @property
    def is_classifier(self) -> bool:
        """Whether the upper-cased type name equals 'CLF'."""
        normalized = self.type_name.upper()
        return normalized == "CLF"

    @property
    def is_regressor(self) -> bool:
        """Whether the upper-cased type name equals 'REG'."""
        normalized = self.type_name.upper()
        return normalized == "REG"
[docs]
@dataclass
class ValueClause(ASTNode):
    """VALUE clause of a PREDICT statement.

    Attributes:
        target: Column to be predicted.
        predict_type: Requested prediction type.
    """

    target: ColumnReference
    predict_type: PredictType
[docs]
@dataclass
class FromClause(ASTNode):
    """FROM clause naming a single table.

    Attributes:
        table: Table name.
    """

    table: str
[docs]
@dataclass
class PredictStatement(ASTNode):
    """PREDICT statement.

    Complete syntax:
        PREDICT VALUE(target_column, predict_type)
        FROM table
        [WHERE conditions]

    Attributes:
        value: The VALUE clause.
        from_table: The FROM clause.
        where: The WHERE clause; None (the default) when absent.
    """

    value: ValueClause
    from_table: FromClause
    where: Optional[WhereClause] = None

    def __repr__(self) -> str:
        """Return a multi-line summary: target, type, table, WHERE marker."""
        clause = self.value
        lines = [
            "PredictStatement(",
            f" target={clause.target}",
            f" type={clause.predict_type.type_name}",
            f" from={self.from_table.table}",
        ]
        # The WHERE line appears only when a (truthy) where clause is set.
        if self.where:
            lines.append(" where=<expression>")
        lines.append(")")
        return "\n".join(lines)
[docs]
@dataclass
class Statement(ASTNode):
    """Container for a TRAIN, PREDICT, or VALIDATE statement.

    When more than one slot is populated, both ``statement_type`` and
    ``repr`` use the first set one in the order train, predict, validate.

    Attributes:
        train: TRAIN statement, or None.
        predict: PREDICT statement, or None.
        validate: VALIDATE statement, or None.
    """

    train: Optional[TrainStatement] = None
    predict: Optional[PredictStatement] = None
    validate: Optional[ValidateStatement] = None

    @property
    def statement_type(self) -> str:
        """Return 'TRAIN', 'PREDICT', 'VALIDATE', or 'UNKNOWN' if none is set."""
        slots = (
            ("TRAIN", self.train),
            ("PREDICT", self.predict),
            ("VALIDATE", self.validate),
        )
        for label, node in slots:
            if node:
                return label
        return "UNKNOWN"

    def __repr__(self) -> str:
        """Return the repr of the first set statement, or 'Statement(empty)'."""
        # `or` yields the first truthy slot, preserving train/predict/validate order.
        active = self.train or self.predict or self.validate
        if active:
            return repr(active)
        return "Statement(empty)"