package dbs

import (
	"context"
	"database/sql"
	"sync"
)

var ErrNoRows = sql.ErrNoRows
var ErrTxDone = sql.ErrTxDone

type Session interface {
	Prepare(query string) (*sql.Stmt, error)
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)

	Exec(query string, args ...interface{}) (sql.Result, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)

	Query(query string, args ...interface{}) (*sql.Rows, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)

	QueryRow(query string, args ...any) *sql.Row
	QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}

type sessionKey struct{}

type Database interface {
	Session

	Close() error

	Session(ctx context.Context) Session

	Begin() (*Tx, error)
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error)
}

type Transaction interface {
	Session

	Rollback() error
	Commit() error
}

func Open(driver, url string, maxOpen, maxIdle int) (*DB, error) {
	db, err := sql.Open(driver, url)
	if err != nil {
		return nil, err
	}

	if err = db.Ping(); err != nil {
		return nil, err
	}

	db.SetMaxIdleConns(maxIdle)
	db.SetMaxOpenConns(maxOpen)
	return New(db), err
}

func New(db *sql.DB) *DB {
	var ndb = &DB{}
	ndb.db = db
	ndb.mu = &sync.RWMutex{}
	ndb.stmts = make(map[string]*Stmt)
	return ndb
}

type DB struct {
	db    *sql.DB
	mu    *sync.RWMutex
	stmts map[string]*Stmt
}

func (db *DB) DB() *sql.DB {
	return db.db
}

func (db *DB) Ping() error {
	return db.db.Ping()
}

func (db *DB) PingContext(ctx context.Context) error {
	return db.db.PingContext(ctx)
}

func (db *DB) Session(ctx context.Context) Session {
	var session, exists = ctx.Value(sessionKey{}).(Session)
	if exists && session != nil {
		return session
	}
	return db
}

// Prepare 作用同 sql.DB 的 Prepare 方法。
//
// 本方法返回的 sql.Stmt 不会被缓存,不再使用之后需要调用其 Close 方法将其关闭。
func (db *DB) Prepare(query string) (*sql.Stmt, error) {
	return db.db.PrepareContext(context.Background(), query)
}

// PrepareContext 作用同 sql.DB 的 PrepareContext 方法。
//
// 本方法返回的 sql.Stmt 不会被缓存,不再使用之后需要调用其 Close 方法将其关闭。
func (db *DB) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
	return db.db.PrepareContext(ctx, query)
}

// PrepareStatement 使用参数 query 创建一个预处理语句(sql.Stmt)并将其缓存,后续可以使用 key 使用该预处理语句。
//
//	var db = dbs.New(...)
//	db.PrepareStatement(ctx, "key", "SELECT ...")
//
//	db.QueryContext(ctx, "key", "参数1", "参数2")
func (db *DB) PrepareStatement(ctx context.Context, key, query string) error {
	_, err := db.prepareStatement(ctx, key, query)
	return err
}

func (db *DB) prepareStatement(ctx context.Context, key, query string) (*sql.Stmt, error) {
	db.mu.RLock()
	if stmt, exists := db.stmts[key]; exists {
		db.mu.RUnlock()
		<-stmt.done
		if stmt.err != nil {
			return nil, stmt.err
		}
		return stmt.stmt, nil
	}
	db.mu.RUnlock()

	db.mu.Lock()
	if stmt, exists := db.stmts[key]; exists {
		db.mu.Unlock()
		<-stmt.done
		if stmt.err != nil {
			return nil, stmt.err
		}
		return stmt.stmt, nil
	}

	var stmt = &Stmt{done: make(chan struct{})}
	db.stmts[key] = stmt
	db.mu.Unlock()

	defer close(stmt.done)

	nStmt, err := db.db.PrepareContext(ctx, query)
	if err != nil {
		stmt.err = err
		db.mu.Lock()
		delete(db.stmts, key)
		db.mu.Unlock()
		return nil, err
	}
	db.mu.Lock()
	stmt.stmt = nStmt
	db.mu.Unlock()
	return nStmt, nil
}

// RevokeStatement 废弃已缓存的预处理语句(sql.Stmt)。
func (db *DB) RevokeStatement(key string) {
	db.mu.RLock()
	var stmt, exists = db.stmts[key]
	db.mu.RUnlock()

	if exists {
		<-stmt.done
		db.removeStatement(key, stmt.stmt)
	}
}

func (db *DB) removeStatement(key string, stmt *sql.Stmt) {
	db.mu.Lock()
	if stmt != nil {
		go stmt.Close()
	}
	delete(db.stmts, key)
	db.mu.Unlock()
}

// statement 使用参数 query 获取已经缓存的预处理语句(sql.Stmt)。
//
// 两种情况:
//   - 缓存中若存在,则直接返回;
//   - 缓存中不存在,则根据 query 参数创建一个预处理语句并将其缓存;
func (db *DB) statement(ctx context.Context, query string) (*sql.Stmt, error) {
	return db.prepareStatement(ctx, query, query)
}

func (db *DB) Exec(query string, args ...interface{}) (sql.Result, error) {
	return db.ExecContext(context.Background(), query, args...)
}

func (db *DB) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	stmt, err := db.statement(ctx, query)
	if err != nil {
		return nil, err
	}
	result, err := stmt.ExecContext(ctx, args...)
	if err != nil {
		db.removeStatement(query, stmt)
	}
	return result, err
}

func (db *DB) Query(query string, args ...interface{}) (*sql.Rows, error) {
	return db.QueryContext(context.Background(), query, args...)
}

func (db *DB) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	stmt, err := db.statement(ctx, query)
	if err != nil {
		return nil, err
	}
	rows, err := stmt.QueryContext(ctx, args...)
	if err != nil {
		db.removeStatement(query, stmt)
	}
	return rows, err
}

func (db *DB) QueryRow(query string, args ...any) *sql.Row {
	return db.QueryRowContext(context.Background(), query, args...)
}

func (db *DB) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	stmt, err := db.statement(ctx, query)
	if err != nil {
		return nil
	}
	row := stmt.QueryRowContext(ctx, args...)
	if row.Err() != nil {
		db.removeStatement(query, stmt)
	}
	return row
}

func (db *DB) Close() error {
	return db.db.Close()
}

func (db *DB) Begin() (*Tx, error) {
	return db.BeginTx(context.Background(), nil)
}

func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
	tx, err := db.db.BeginTx(ctx, opts)
	if err != nil {
		return nil, err
	}
	var nTx = &Tx{}
	nTx.tx = tx
	nTx.db = db
	return nTx, nil
}

type Tx struct {
	tx *sql.Tx
	db *DB
}

func (tx *Tx) Tx() *sql.Tx {
	return tx.tx
}

// Prepare 作用同 sql.Tx 的 Prepare 方法。
//
// 本方法返回的 sql.Stmt 不会被缓存,不再使用之后需要调用其 Close 方法将其关闭。
func (tx *Tx) Prepare(query string) (*sql.Stmt, error) {
	return tx.PrepareContext(context.Background(), query)
}

// PrepareContext 作用同 sql.Tx 的 PrepareContext 方法。
//
// 本方法返回的 sql.Stmt 不会被缓存,不再使用之后需要调用其 Close 方法将其关闭。
func (tx *Tx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) {
	return tx.tx.PrepareContext(ctx, query)
}

func (tx *Tx) Statement(ctx context.Context, query string) (*sql.Stmt, error) {
	var stmt, err = tx.db.statement(ctx, query)
	if err != nil {
		return nil, err
	}
	return tx.tx.StmtContext(ctx, stmt), nil
}

func (tx *Tx) Exec(query string, args ...interface{}) (sql.Result, error) {
	return tx.ExecContext(context.Background(), query, args...)
}

func (tx *Tx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
	stmt, err := tx.Statement(ctx, query)
	if err != nil {
		return nil, err
	}
	result, err := stmt.ExecContext(ctx, args...)
	if err != nil {
		tx.db.removeStatement(query, stmt)
	}
	return result, err
}

func (tx *Tx) Query(query string, args ...interface{}) (*sql.Rows, error) {
	return tx.QueryContext(context.Background(), query, args...)
}

func (tx *Tx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
	stmt, err := tx.Statement(ctx, query)
	if err != nil {
		return nil, err
	}
	rows, err := stmt.QueryContext(ctx, args...)
	if err != nil {
		tx.db.removeStatement(query, stmt)
	}
	return rows, err
}

func (tx *Tx) QueryRow(query string, args ...any) *sql.Row {
	return tx.QueryRowContext(context.Background(), query, args...)
}

func (tx *Tx) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
	stmt, err := tx.Statement(ctx, query)
	if err != nil {
		return nil
	}
	row := stmt.QueryRowContext(ctx, args...)
	if row.Err() != nil {
		tx.db.removeStatement(query, stmt)
	}
	return row
}

func (tx *Tx) ToContext(ctx context.Context) context.Context {
	return context.WithValue(ctx, sessionKey{}, tx)
}

func (tx *Tx) Commit() error {
	return tx.tx.Commit()
}

func (tx *Tx) Rollback() error {
	return tx.tx.Rollback()
}