1 /++
2     MariaDB Database Driver
3 
4     Supposed to be MySQL compatible, as well.
5 
6     ---
7     DatabaseDriver db = new MariaDBDatabaseDriver(
8         "localhost",
9         "username",
10         "password",
11         "database name",    // (optional) “Database” on the server to use initially
12         3306,               // (optional) MariaDB server port
13     );
14 
15     db.connect(); // establish database connection
16     scope(exit) db.close(); // scope guard, to close the database connection when exiting the current scope
17     ---
18 
19     $(NOTE
        If you don’t specify an initial database during the connection setup,
21         you’ll usually want to manually select one by executing a `USE databaseName;` statement.
22     )
23  +/
24 module oceandrift.db.mariadb;
25 
26 import mysql.safe;
27 import oceandrift.db.dbal.driver;
28 import oceandrift.db.dbal.v4;
29 import std.array : appender, Appender;
30 import std.conv : to;
31 
32 @safe:
33 
/// oceandrift’s DBAL row type (both imported modules define a `Row`, hence the explicit alias)
alias DBALRow = oceandrift.db.dbal.driver.Row;
/// Row type of mysql-native’s @safe API
alias MySQLRow = mysql.safe.Row;
36 
/++
    MariaDB database driver for oceandrift

    Built upon mysql-native; uses its @safe API.

    See_Also:
        https://code.dlang.org/packages/mysql-native
 +/
final class MariaDB : DatabaseDriverSpec
{
@safe:

    private
    {
        Connection _connection;

        string _host;
        ushort _port;
        string _username;
        string _password;
        string _database;
    }

    /++
        Constructor; stores the connection parameters

        No connection is established here – call [connect] to do so.

        Params:
            host = database host (the underlying mysql-native currently only supports TCP connections, unfortunately)
            username = MariaDB user
            password = password of the MariaDB user
            database = initial database to use (optional; `null` selects none)
            port = MariaDB server port
     +/
    public this(string host, string username, string password, string database = null, ushort port = 3306)
    {
        _host = host;
        _port = port;
        _username = username;
        _password = password;
        _database = database;
    }

    public
    {
        /++
            Establishes a connection using the parameters passed to the constructor
         +/
        void connect()
        {
            _connection = new Connection(
                _host,
                _username,
                _password,
                _database,
                _port
            );
        }

        /++
            Closes the database connection

            Does nothing if [connect] has never been called.
         +/
        void close()
        {
            // `_connection` stays null until connect() has been called.
            if (_connection is null)
                return;

            _connection.close();
        }

        /++
            Returns: whether a connection object exists and has not been closed
         +/
        bool connected()
        {
            return ((this._connection !is null)
                    && !this._connection.closed);
        }
    }

    public
    {
        /++
            Returns: whether `@@autocommit` is currently enabled on the server session
         +/
        bool autoCommit()
        {
            return this._connection
                .queryRow("SELECT @@autocommit")
                .get[0] != 0;
        }

        /++
            Enables or disables autocommit for the current session
         +/
        void autoCommit(bool enable)
        {
            if (enable)
                _connection.exec("SET autocommit=1");
            else
                _connection.exec("SET autocommit=0");
        }

        /// Executes `START TRANSACTION`
        void transactionStart()
        {
            _connection.exec("START TRANSACTION");
        }

        /// Executes `COMMIT`
        void transactionCommit()
        {
            _connection.exec("COMMIT");
        }

        /// Executes `ROLLBACK`
        void transactionRollback()
        {
            _connection.exec("ROLLBACK");
        }
    }

    public
    {
        /++
            Executes the passed SQL statement directly (result sets are not retrieved)
         +/
        void execute(string sql)
        {
            this._connection.exec(sql);
        }

        /++
            Prepares the passed SQL statement on the current connection
         +/
        Statement prepare(string sql)
        {
            return new MariaDBStatement(this._connection, sql);
        }

        /++
            Returns: the ID mysql-native reports for the last insert on this connection
         +/
        DBValue lastInsertID()
        {
            return DBValue(_connection.lastInsertID());
        }
    }

    public  // Extras
    {
        /++
            Returns: the underlying mysql-native connection (`null` before [connect] has been called)
         +/
        Connection getConnection()
        {
            return this._connection;
        }
    }

    public pure  // Query Compiler
    {
        /++
            Compiles a SELECT query to MariaDB SQL
         +/
        static BuiltQuery build(const Select select)
        {
            auto sql = appender!string("SELECT");

            foreach (idx, se; select.columns)
            {
                if (idx > 0)
                    sql ~= ',';

                se.toSQL(sql);
            }

            sql ~= " FROM `";
            sql ~= select.query.table.name.escapeIdentifier();
            sql ~= '`';

            const query = CompilerQuery(select.query);
            query.join.joinToSQL(sql);
            query.where.whereToSQL(sql);
            query.orderByToSQL(sql);
            query.limitToSQL(sql);

            return BuiltQuery(
                sql.data,
                PlaceholdersMeta(query.where.placeholders),
                PreSets(query.where.preSet, query.limit.preSet, query.limit.offsetPreSet)
            );
        }

        /++
            Compiles an UPDATE query to MariaDB SQL (joins are not supported)
         +/
        static BuiltQuery build(const Update update)
        in (update.columns.length >= 1)
        in (CompilerQuery(update.query).join.length == 0)
        {
            auto sql = appender!string("UPDATE");
            sql ~= " `";
            sql ~= update.query.table.name.escapeIdentifier();
            sql ~= "` SET";

            foreach (idx, value; update.columns)
            {
                if (idx > 0)
                    sql ~= ',';

                sql ~= " `";
                sql ~= value.escapeIdentifier;
                sql ~= "` = ?";
            }

            const query = CompilerQuery(update.query);
            query.where.whereToSQL(sql);
            query.orderByToSQL(sql);
            query.limitToSQL(sql);

            return BuiltQuery(
                sql.data,
                PlaceholdersMeta(query.where.placeholders),
                PreSets(query.where.preSet, query.limit.preSet, query.limit.offsetPreSet)
            );
        }

        /++
            Compiles an INSERT query to MariaDB SQL

            Emits one `(?,…)` placeholder tuple per row.
            Without any columns, `() VALUES ()` is emitted:
            MariaDB/MySQL do not understand `DEFAULT VALUES`, but accept an empty column list
            to insert a row consisting of default values only.
         +/
        static BuiltQuery build(const Insert query)
        in (
            (query.columns.length >= 1)
            || (query.rowCount == 1)
        )
        {
            auto sql = appender!string("INSERT INTO `");
            sql ~= escapeIdentifier(query.table.name);
            sql ~= "` (";

            foreach (idx, column; query.columns)
            {
                if (idx > 0)
                    sql ~= ", ";

                sql ~= '`';
                sql ~= escapeIdentifier(column);
                sql ~= '`';
            }

            sql ~= ") VALUES";

            for (uint n = 0; n < query.rowCount; ++n)
            {
                if (n > 0)
                    sql ~= ",";

                // one placeholder per column; an empty tuple when there are no columns
                sql ~= " (";
                foreach (i; 0 .. query.columns.length)
                {
                    if (i > 0)
                        sql ~= ',';

                    sql ~= '?';
                }
                sql ~= ')';
            }

            return BuiltQuery(sql.data);
        }

        /++
            Compiles a DELETE query to MariaDB SQL (joins are not supported)
         +/
        static BuiltQuery build(const Delete delete_)
        in (CompilerQuery(delete_.query).join.length == 0)
        {
            auto sql = appender!string("DELETE FROM `");
            sql ~= delete_.query.table.name.escapeIdentifier();
            sql ~= '`';

            const query = CompilerQuery(delete_.query);

            query.where.whereToSQL(sql);
            query.orderByToSQL(sql);
            query.limitToSQL(sql);

            return BuiltQuery(
                sql.data,
                PlaceholdersMeta(query.where.placeholders),
                PreSets(query.where.preSet, query.limit.preSet, query.limit.offsetPreSet)
            );
        }
    }
}
297 
/// Alternative name for [MariaDB] (the one used in the module-level example)
alias MariaDBDatabaseDriver = MariaDB;
299 
/+
    Generates a `bind` overload accepting values of type `T`

    The class mixing this in must provide a `_stmt` member offering `setArg(index, value)`
    (in MariaDBStatement that is mysql-native’s `Prepared`).
 +/
private mixin template bindImpl(T)
{
    @safe void bind(int index, const T value)
    {
        this._stmt.setArg(index, value);
    }
}
307 
/+
    Wrapper over mysql-native’s Prepared and ResultRange

    undocumented on purpose – shouldn’t be used directly, just stick to [oceandrift.db.dbal.Statement]
 +/
final class MariaDBStatement : Statement
{
@safe:

    private
    {
        Connection _connection;
        Prepared _stmt;
        ResultRange _result;
        DBALRow _front;
    }

    private this(Connection connection, string sql)
    {
        _connection = connection;
        _stmt = _connection.prepare(sql);
    }

    private
    {
        /+
            Converts the current row of `_result` into `_front`.

            Resets `_front` once the result is empty/exhausted,
            so that no row of a previous query or iteration step lingers around.
         +/
        void loadFront()
        {
            if (_result.empty)
                _front = null;
            else
                _front = _result.front.mysqlToDBAL();
        }
    }

    public
    {
        /+
            Executes the prepared statement and loads the first row (if any)
         +/
        void execute()
        {
            try
            {
                _result = _connection.query(_stmt);
                this.loadFront(); // apparently result being empty can be the case
            }
            catch (MYXNoResultRecieved)
            {
                // workaround because of MYXNoResultRecieved

                // The executed query did not produce a result set.
                // mysql-native wants us to «use the exec functions, not query, for commands that don't produce result sets».

                _result = typeof(_result).init;
                _front = null;
            }
        }

        void close()
        {
            _result.close();
        }
    }

    public
    {
        bool empty() pure nothrow
        {
            return _result.empty;
        }

        DBALRow front() pure nothrow @nogc
        {
            return _front;
        }

        void popFront()
        {
            _result.popFront();
            this.loadFront();
        }
    }

    public
    {
        mixin bindImpl!bool;
        mixin bindImpl!byte;
        mixin bindImpl!ubyte;
        mixin bindImpl!short;
        mixin bindImpl!ushort;
        mixin bindImpl!int;
        mixin bindImpl!uint;
        mixin bindImpl!long;
        mixin bindImpl!ulong;
        mixin bindImpl!double;
        mixin bindImpl!string;
        mixin bindImpl!DateTime;
        mixin bindImpl!TimeOfDay;
        mixin bindImpl!Date;
        mixin bindImpl!(const(ubyte)[]);
        mixin bindImpl!(typeof(null));
    }
}
400 
/++
    Converts a mysql-native result row ([MySQLRow]) into an [oceandrift.db.dbal.driver.Row]

    Each column is converted on its own through the [MySQLVal] overload of `mysqlToDBAL`.
 +/
DBALRow mysqlToDBAL(MySQLRow mysql)
{
    DBValue[] columns = new DBValue[](mysql.length);

    foreach (idx; 0 .. mysql.length)
    {
        // indexing into the mysql-native row (and converting the value) happens in a @trusted context
        columns[idx] = (() @trusted => mysql[idx].mysqlToDBAL())();
    }

    return oceandrift.db.dbal.driver.Row(columns);
}
413 
/++
    Creates a [DBValue] from a [MySQLVal]

    Dispatches on the tagged union’s `kind` and re-wraps the held value in a [DBValue].
    For the `*Ref` kinds the stored pointer is dereferenced, so the resulting [DBValue]
    holds the pointed-to value rather than the pointer.
    Timestamps (`Timestamp` and `TimestampRef`) are unwrapped to their `rep` member.
 +/
DBValue mysqlToDBAL(MySQLVal mysql)
{
    // `get` is called UFCS-style on `mysql` below.
    // NOTE(review): `hasType` does not appear to be used in this function.
    import taggedalgebraic : get, hasType;

    // `final switch`: every member of `MySQLVal.Kind` must be listed,
    // so a kind added to mysql-native shows up as a compile error instead of a silent gap.
    final switch (mysql.kind) with (MySQLVal)
    {
    // --- values stored directly ---
    case Kind.Blob:
        return DBValue(mysql.get!(ubyte[]));
    case Kind.CBlob:
        return DBValue(mysql.get!(const(ubyte)[]));
    case Kind.Null:
        return DBValue(mysql.get!(null_t));
    case Kind.Bit:
        return DBValue(mysql.get!(bool));
    case Kind.UByte:
        return DBValue(mysql.get!(ubyte));
    case Kind.Byte:
        return DBValue(mysql.get!(byte));
    case Kind.UShort:
        return DBValue(mysql.get!(ushort));
    case Kind.Short:
        return DBValue(mysql.get!(short));
    case Kind.UInt:
        return DBValue(mysql.get!(uint));
    case Kind.Int:
        return DBValue(mysql.get!(int));
    case Kind.ULong:
        return DBValue(mysql.get!(ulong));
    case Kind.Long:
        return DBValue(mysql.get!(long));
    case Kind.Float:
        return DBValue(mysql.get!(float));
    case Kind.Double:
        return DBValue(mysql.get!(double));
    case Kind.DateTime:
        return DBValue(mysql.get!(DateTime));
    case Kind.Time:
        return DBValue(mysql.get!(TimeOfDay));
    case Kind.Timestamp:
        // only the `rep` member of the Timestamp is passed on
        return DBValue(mysql.get!(Timestamp).rep);
    case Kind.Date:
        return DBValue(mysql.get!(Date));
    case Kind.Text:
        return DBValue(mysql.get!(string));
    case Kind.CText:
        return DBValue(mysql.get!(const(char)[]));

    // --- values stored by pointer: dereference and wrap the pointee ---
    case Kind.BitRef:
        return DBValue(*mysql.get!(const(bool)*));
    case Kind.UByteRef:
        return DBValue(*mysql.get!(const(ubyte)*));
    case Kind.ByteRef:
        return DBValue(*mysql.get!(const(byte)*));
    case Kind.UShortRef:
        return DBValue(*mysql.get!(const(ushort)*));
    case Kind.ShortRef:
        return DBValue(*mysql.get!(const(short)*));
    case Kind.UIntRef:
        return DBValue(*mysql.get!(const(uint)*));
    case Kind.IntRef:
        return DBValue(*mysql.get!(const(int)*));
    case Kind.ULongRef:
        return DBValue(*mysql.get!(const(ulong)*));
    case Kind.LongRef:
        return DBValue(*mysql.get!(const(long)*));
    case Kind.FloatRef:
        return DBValue(*mysql.get!(const(float)*));
    case Kind.DoubleRef:
        return DBValue(*mysql.get!(const(double)*));
    case Kind.DateTimeRef:
        return DBValue(*mysql.get!(const(DateTime)*));
    case Kind.TimeRef:
        return DBValue(*mysql.get!(const(TimeOfDay)*));
    case Kind.DateRef:
        return DBValue(*mysql.get!(const(Date)*));
    case Kind.TextRef:
        return DBValue(*mysql.get!(const(string)*));
    case Kind.CTextRef:
        // the referenced characters are copied (`.dup`)
        // NOTE(review): the by-value `CText` case above is passed on without a copy – confirm the asymmetry is intended.
        return DBValue((*mysql.get!(const(char[])*)).dup);
    case Kind.BlobRef:
        return DBValue(*mysql.get!(const(ubyte[])*));
    case Kind.TimestampRef:
        // member access auto-dereferences the pointer; only `rep` is passed on
        return DBValue(mysql.get!(const(Timestamp)*).rep);
    }
}
501 
502 private pure
503 {
504     void joinToSQL(const Join[] joinClause, ref Appender!string sql)
505     {
506         foreach (join; joinClause)
507         {
508             final switch (join.type) with (Join)
509             {
510             case Type.invalid:
511                 assert(0, "Join.Type.invalid");
512 
513             case Type.inner:
514                 sql ~= " JOIN `";
515                 break;
516 
517             case Type.leftOuter:
518                 sql ~= " LEFT OUTER JOIN `";
519                 break;
520 
521             case Type.rightOuter:
522                 sql ~= " RIGHT OUTER JOIN `";
523                 break;
524 
525             case Type.fullOuter:
526                 assert(false, "MariaDB does not support FULL OUTER JOINs");
527                 //sql ~= " FULL OUTER JOIN `"`;
528                 //break;
529 
530             case Type.cross:
531                 sql ~= " CROSS JOIN `";
532                 break;
533             }
534 
535             sql ~= escapeIdentifier(join.target.table.name);
536             sql ~= '`';
537 
538             if (join.target.name is null)
539                 return;
540 
541             sql ~= " ON `";
542             sql ~= escapeIdentifier(join.target.table.name);
543             sql ~= "`.`";
544             sql ~= escapeIdentifier(join.target.name);
545             sql ~= "` = `";
546 
547             if (join.source.table.name !is null)
548             {
549                 sql ~= escapeIdentifier(join.source.table.name);
550                 sql ~= "`.`";
551             }
552 
553             sql ~= escapeIdentifier(join.source.name);
554             sql ~= '`';
555         }
556     }
557 
558     void whereToSQL(const Where where, ref Appender!string sql)
559     {
560         if (where.tokens.length == 0)
561             return;
562 
563         sql ~= " WHERE";
564 
565         Token.Type prev;
566 
567         foreach (Token t; where.tokens)
568         {
569             final switch (t.type) with (Token)
570             {
571             case Type.columnTable:
572                 sql ~= " `";
573                 (delegate() @trusted { sql ~= t.data.str.escapeIdentifier(); })();
574                 sql ~= "`.";
575                 break;
576             case Type.column:
577                 if (prev != Type.columnTable)
578                     sql ~= ' ';
579                 sql ~= '`';
580                 (delegate() @trusted { sql ~= t.data.str.escapeIdentifier(); })();
581                 sql ~= '`';
582                 break;
583             case Type.placeholder:
584                 sql ~= " ?";
585                 break;
586             case Type.comparisonOperator:
587                 sql ~= t.data.op.toSQL;
588                 break;
589 
590             case Type.and:
591                 sql ~= " AND";
592                 break;
593             case Type.or:
594                 sql ~= " OR";
595                 break;
596 
597             case Type.not:
598                 sql ~= " NOT";
599                 break;
600 
601             case Type.leftParenthesis:
602                 sql ~= " (";
603                 break;
604             case Type.rightParenthesis:
605                 sql ~= " )";
606                 break;
607 
608             case Type.invalid:
609                 assert(0, "Invalid SQL token in where clause");
610             }
611 
612             prev = t.type;
613         }
614     }
615 
616     void limitToSQL(CompilerQuery q, ref Appender!string sql)
617     {
618         if (!q.limit.enabled)
619             return;
620 
621         sql ~= " LIMIT ?";
622 
623         if (!q.limit.offsetEnabled)
624             return;
625 
626         sql ~= " OFFSET ?";
627     }
628 
629     void orderByToSQL(CompilerQuery q, ref Appender!string sql)
630     {
631         if (q.orderBy.length == 0)
632             return;
633 
634         sql ~= " ORDER BY ";
635 
636         foreach (idx, OrderingTerm term; q.orderBy)
637         {
638             if (idx > 0)
639                 sql ~= ", ";
640 
641             if (term.column.table.name !is null)
642             {
643                 sql ~= '`';
644                 sql ~= escapeIdentifier(term.column.table.name);
645                 sql ~= "`.";
646             }
647             sql ~= '`';
648             sql ~= escapeIdentifier(term.column.name);
649             sql ~= '`';
650 
651             if (term.orderingSequence == OrderingSequence.desc)
652                 sql ~= " DESC";
653         }
654     }
655 
656     void toSQL(SelectExpression se, ref Appender!string sql)
657     {
658         sql ~= ' ';
659 
660         enum switchCase(string aggr) = `case ` ~ aggr ~ `: sql ~= "` ~ aggr ~ `("; break;`;
661 
662         final switch (se.aggregateFunction) with (AggregateFunction)
663         {
664             mixin(switchCase!"avg");
665             mixin(switchCase!"count");
666             mixin(switchCase!"max");
667             mixin(switchCase!"min");
668             mixin(switchCase!"sum");
669             mixin(switchCase!"group_concat");
670         case none:
671             break;
672         }
673 
674         if (se.distinct)
675             sql ~= "DISTINCT ";
676 
677         if (se.column.table.name !is null)
678         {
679             sql ~= '`';
680             sql ~= se.column.table.name;
681             sql ~= "`.";
682         }
683 
684         if (se.column.name == "*")
685         {
686             sql ~= '*';
687         }
688         else
689         {
690             sql ~= '`';
691             sql ~= se.column.name.escapeIdentifier;
692             sql ~= '`';
693         }
694 
695         if (se.aggregateFunction != AggregateFunction.none)
696             sql ~= ')';
697     }
698 
699     string toSQL(ComparisonOperator op)
700     {
701         final switch (op) with (ComparisonOperator)
702         {
703         case invalid:
704             assert(0, "Invalid comparison operator");
705 
706         case equals:
707             return " =";
708         case notEquals:
709             return " <>";
710         case lessThan:
711             return " <";
712         case greaterThan:
713             return " >";
714         case lessThanOrEquals:
715             return " <=";
716         case greaterThanOrEquals:
717             return " >=";
718         case in_:
719             return " IN";
720         case notIn:
721             return " NOT IN";
722         case like:
723             return " LIKE";
724         case notLike:
725             return " NOT LIKE";
726         case isNull:
727             return " IS NULL";
728         case isNotNull:
729             return " IS NOT NULL";
730         }
731     }
732 
733     string escapeIdentifier(string tableOrColumn)
734     {
735         import std.string : replace;
736 
737         return tableOrColumn.replace('`', "``");
738     }
739 }