TiDB SQL source code
TiDB SQL code
File path: /dumpling/export/sql.go
// Copyright 2020 PingCAP, Inc. Licensed under Apache-2.0.
package export
import (
"bytes"
"context"
"database/sql"
"encoding/json"
"fmt"
"io"
"math"
"net/url"
"strconv"
"strings"
"github.com/go-sql-driver/mysql"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/br/pkg/version"
dbconfig "github.com/pingcap/tidb/config"
tcontext "github.com/pingcap/tidb/dumpling/context"
"github.com/pingcap/tidb/dumpling/log"
"github.com/pingcap/tidb/errno"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/store/helper"
"go.uber.org/multierr"
"go.uber.org/zap"
)
// orderByTiDBRowID orders a SELECT by TiDB's implicit `_tidb_rowid` column.
const orderByTiDBRowID = "ORDER BY `_tidb_rowid`"

// listTableType selects which statement ListAllDatabasesTables uses to
// enumerate tables.
type listTableType int

// Table listing strategies; values start at 0 and increase by one.
const (
	// listTableByInfoSchema queries INFORMATION_SCHEMA.TABLES once for all schemas.
	listTableByInfoSchema listTableType = iota
	// listTableByShowFullTables runs SHOW FULL TABLES once per schema.
	listTableByShowFullTables
	// listTableByShowTableStatus runs SHOW TABLE STATUS once per schema.
	listTableByShowTableStatus
)
// ShowDatabases shows the databases of a database server.
// It runs SHOW DATABASES on db, feeding each result row to
// oneStrColumnTable.handleOneRow, and returns the values it collected.
func ShowDatabases(db *sql.Conn) ([]string, error) {
	var collected oneStrColumnTable
	err := simpleQuery(db, "SHOW DATABASES", collected.handleOneRow)
	if err != nil {
		return nil, err
	}
	return collected.data, nil
}
// ShowTables shows the tables of a database, the caller should use the correct database.
// It runs SHOW TABLES on db (so the session's current database is used), feeding
// each row to oneStrColumnTable.handleOneRow, and returns the collected values.
func ShowTables(db *sql.Conn) ([]string, error) {
	var collected oneStrColumnTable
	err := simpleQuery(db, "SHOW TABLES", collected.handleOneRow)
	if err != nil {
		return nil, err
	}
	return collected.data, nil
}
// ShowCreateDatabase constructs the create database SQL for a specified database.
// It runs SHOW CREATE DATABASE and returns the second column of the result.
// When the server rejects the statement with a MySQL error whose message
// mentions "SHOW CREATE DATABASE" (e.g. MemSQL/SingleStore:
// ERROR 1706 (HY000): Feature 'SHOW CREATE DATABASE' is not supported by MemSQL.),
// a plain "CREATE DATABASE `database`" statement is returned instead.
// returns (createDatabaseSQL, error); the SQL is empty whenever error is non-nil.
func ShowCreateDatabase(tctx *tcontext.Context, db *BaseConn, database string) (string, error) {
	var oneRow [2]string
	handleOneRow := func(rows *sql.Rows) error {
		return rows.Scan(&oneRow[0], &oneRow[1])
	}
	query := fmt.Sprintf("SHOW CREATE DATABASE `%s`", escapeString(database))
	err := db.QuerySQL(tctx, handleOneRow, func() {
		oneRow[0], oneRow[1] = "", ""
	}, query)
	if err == nil {
		return oneRow[1], nil
	}
	// err may combine several errors (multierr); inspect each of them.
	for _, multiErr := range multierr.Errors(err) {
		if mysqlErr, ok := errors.Cause(multiErr).(*mysql.MySQLError); ok {
			// Falling back to simple create statement for MemSQL/SingleStore.
			if strings.Contains(mysqlErr.Error(), "SHOW CREATE DATABASE") {
				return fmt.Sprintf("CREATE DATABASE `%s`", escapeString(database)), nil
			}
		}
	}
	// Do not hand back a partially scanned row together with the error.
	return "", err
}
// ShowCreateTable constructs the create table SQL for a specified table.
// It runs SHOW CREATE TABLE `database`.`table` and returns the second column
// of the result; on error it returns an empty string and the error.
func ShowCreateTable(tctx *tcontext.Context, db *BaseConn, database, table string) (string, error) {
	var tableName, createSQL string
	scanRow := func(rows *sql.Rows) error {
		return rows.Scan(&tableName, &createSQL)
	}
	resetRow := func() {
		tableName, createSQL = "", ""
	}
	query := fmt.Sprintf("SHOW CREATE TABLE `%s`.`%s`", escapeString(database), escapeString(table))
	if err := db.QuerySQL(tctx, scanRow, resetRow, query); err != nil {
		return "", err
	}
	return createSQL, nil
}
// ShowCreatePlacementPolicy constructs the create policy SQL for a specified placement policy.
// It runs SHOW CREATE PLACEMENT POLICY and returns the second column of the result.
// returns (createPolicySQL, error); the SQL is empty whenever error is non-nil.
func ShowCreatePlacementPolicy(tctx *tcontext.Context, db *BaseConn, policy string) (string, error) {
	var oneRow [2]string
	handleOneRow := func(rows *sql.Rows) error {
		return rows.Scan(&oneRow[0], &oneRow[1])
	}
	query := fmt.Sprintf("SHOW CREATE PLACEMENT POLICY `%s`", escapeString(policy))
	err := db.QuerySQL(tctx, handleOneRow, func() {
		oneRow[0], oneRow[1] = "", ""
	}, query)
	if err != nil {
		// Consistent with ShowCreateTable: no partial result alongside an error.
		return "", err
	}
	return oneRow[1], nil
}
// ShowCreateView constructs the create view SQL for a specified view.
// It returns two scripts:
//   - createFakeTableSQL: a CREATE TABLE named after the view whose columns are
//     the view's fields (each typed int) with ENGINE=MyISAM;
//   - createRealViewSQL: DROP TABLE/DROP VIEW IF EXISTS for that name, followed by
//     the SHOW CREATE VIEW statement wrapped in the view's character_set_client
//     and collation_connection (see SetCharset / RestoreCharset).
//
// NOTE(review): the placeholder table presumably lets dependent objects be
// restored before the real view exists — confirm with the writer side.
func ShowCreateView(tctx *tcontext.Context, db *BaseConn, database, view string) (createFakeTableSQL string, createRealViewSQL string, err error) {
	escapedView := escapeString(view)

	// SHOW FIELDS yields 6 columns per row; only the field name (first column) is used.
	var columnDefs []string
	scanField := func(rows *sql.Rows) error {
		var fieldRow [6]sql.NullString
		if scanErr := rows.Scan(&fieldRow[0], &fieldRow[1], &fieldRow[2], &fieldRow[3], &fieldRow[4], &fieldRow[5]); scanErr != nil {
			return errors.Trace(scanErr)
		}
		if name := fieldRow[0]; name.Valid {
			columnDefs = append(columnDefs, fmt.Sprintf("`%s` int", escapeString(name.String)))
		}
		return nil
	}
	fieldsQuery := fmt.Sprintf("SHOW FIELDS FROM `%s`.`%s`", escapeString(database), escapedView)
	if err = db.QuerySQL(tctx, scanField, func() { columnDefs = []string{} }, fieldsQuery); err != nil {
		return "", "", err
	}

	var tableSQL strings.Builder
	fmt.Fprintf(&tableSQL, "CREATE TABLE `%s`(\n", escapedView)
	tableSQL.WriteString(strings.Join(columnDefs, ",\n"))
	tableSQL.WriteString("\n)ENGINE=MyISAM;\n")

	// SHOW CREATE VIEW returns: View | Create View | character_set_client | collation_connection, e.g.
	// | v1 | CREATE ALGORITHM=UNDEFINED DEFINER=`root`@`localhost` SQL SECURITY DEFINER VIEW `v1` (`a`) AS SELECT `t`.`a` AS `a` FROM `test`.`t` | utf8 | utf8_general_ci |
	var viewRow [4]string
	scanView := func(rows *sql.Rows) error {
		return rows.Scan(&viewRow[0], &viewRow[1], &viewRow[2], &viewRow[3])
	}
	resetView := func() {
		viewRow = [4]string{}
	}
	viewQuery := fmt.Sprintf("SHOW CREATE VIEW `%s`.`%s`", escapeString(database), escapedView)
	if err = db.QuerySQL(tctx, scanView, resetView, viewQuery); err != nil {
		return "", "", err
	}

	var viewSQL strings.Builder
	fmt.Fprintf(&viewSQL, "DROP TABLE IF EXISTS `%s`;\n", escapedView)
	fmt.Fprintf(&viewSQL, "DROP VIEW IF EXISTS `%s`;\n", escapedView)
	SetCharset(&viewSQL, viewRow[2], viewRow[3])
	viewSQL.WriteString(viewRow[1])
	viewSQL.WriteString(";\n")
	RestoreCharset(&viewSQL)
	return tableSQL.String(), viewSQL.String(), nil
}
// ShowCreateSequence constructs the create sequence SQL for a specified sequence.
// The result is the SHOW CREATE SEQUENCE statement plus, for TiDB and MariaDB
// servers, a "SELECT SETVAL(`sequence`,N);" that restores the next value:
//   - TiDB: N is NEXT_GLOBAL_ROW_ID of the SEQUENCE row of SHOW TABLE ... NEXT_ROW_ID;
//   - MariaDB: N is NEXT_NOT_CACHED_VALUE read from the sequence table.
//
// Unparsable values are not reported; strconv.ParseInt's result (0 on a syntax
// error) is used as N. Other server types get only the CREATE statement.
// returns (createSequenceSQL, error)
func ShowCreateSequence(tctx *tcontext.Context, db *BaseConn, database, sequence string, conf *Config) (string, error) {
	var seqName, createSeq string
	showQuery := fmt.Sprintf("SHOW CREATE SEQUENCE `%s`.`%s`", escapeString(database), escapeString(sequence))
	err := db.QuerySQL(tctx, func(rows *sql.Rows) error {
		return rows.Scan(&seqName, &createSeq)
	}, func() {
		seqName, createSeq = "", ""
	}, showQuery)
	if err != nil {
		return "", err
	}

	var out strings.Builder
	out.WriteString(createSeq)
	out.WriteString(";\n")

	var nextValue int64
	switch conf.ServerInfo.ServerType {
	case version.ServerTypeTiDB:
		// Get next not allocated auto increment id of the whole cluster.
		rowIDQuery := fmt.Sprintf("SHOW TABLE `%s`.`%s` NEXT_ROW_ID", escapeString(database), escapeString(sequence))
		idRows, queryErr := db.QuerySQLWithColumns(tctx, []string{"NEXT_GLOBAL_ROW_ID", "ID_TYPE"}, rowIDQuery)
		if queryErr != nil {
			return "", queryErr
		}
		for _, idRow := range idRows {
			if idRow[1] == "SEQUENCE" {
				nextValue, _ = strconv.ParseInt(idRow[0], 10, 64)
			}
		}
		fmt.Fprintf(&out, "SELECT SETVAL(`%s`,%d);\n", escapeString(sequence), nextValue)
	case version.ServerTypeMariaDB:
		var nextNotCached string
		valueQuery := fmt.Sprintf("SELECT NEXT_NOT_CACHED_VALUE FROM `%s`.`%s`", escapeString(database), escapeString(sequence))
		queryErr := db.QuerySQL(tctx, func(rows *sql.Rows) error {
			return rows.Scan(&nextNotCached)
		}, func() {
			nextNotCached = ""
		}, valueQuery)
		if queryErr != nil {
			return "", queryErr
		}
		nextValue, _ = strconv.ParseInt(nextNotCached, 10, 64)
		fmt.Fprintf(&out, "SELECT SETVAL(`%s`,%d);\n", escapeString(sequence), nextValue)
	}
	return out.String(), nil
}
// SetCharset builds the set charset SQLs.
// It writes statements that save the current character_set_client,
// character_set_results and collation_connection into @PREV_* user variables,
// then sets client/results to characterSet and the connection collation to
// collationConnection. RestoreCharset writes the matching restore statements.
func SetCharset(w *strings.Builder, characterSet, collationConnection string) {
	for _, sysVar := range [...]string{"CHARACTER_SET_CLIENT", "CHARACTER_SET_RESULTS", "COLLATION_CONNECTION"} {
		fmt.Fprintf(w, "SET @PREV_%s=@@%s;\n", sysVar, sysVar)
	}
	for _, assign := range [...][2]string{
		{"character_set_client", characterSet},
		{"character_set_results", characterSet},
		{"collation_connection", collationConnection},
	} {
		fmt.Fprintf(w, "SET %s = %s;\n", assign[0], assign[1])
	}
}
// RestoreCharset builds the restore charset SQLs.
// It writes statements that reset character_set_client, character_set_results
// and collation_connection from the @PREV_* user variables saved by
// SetCharset. Write errors are deliberately ignored.
func RestoreCharset(w io.StringWriter) {
	restores := [...]string{
		"SET character_set_client = @PREV_CHARACTER_SET_CLIENT;\n",
		"SET character_set_results = @PREV_CHARACTER_SET_RESULTS;\n",
		"SET collation_connection = @PREV_COLLATION_CONNECTION;\n",
	}
	for _, stmt := range restores {
		_, _ = w.WriteString(stmt)
	}
}
// ListAllDatabasesTables lists all the databases and tables from the database.
// Every name in databaseNames gets an entry in the result (possibly empty), and
// only tables whose type is one of tableTypes are listed.
//   - listTableByInfoSchema lists tables by table information_schema in MySQL (one query for all schemas)
//   - listTableByShowTableStatus (the default branch) has better performance than listTableByInfoSchema
//   - listTableByShowFullTables is used in mysql8 version [8.0.3,8.0.23), more details can be found in the
//     comments of func matchMysqlBugversion; it cannot report AVG_ROW_LENGTH, so 0 is used
//
// With SHOW TABLE STATUS, a table without an engine whose comment is neither
// empty nor the view marker is skipped with a warning.
func ListAllDatabasesTables(tctx *tcontext.Context, db *sql.Conn, databaseNames []string,
	listType listTableType, tableTypes ...TableType) (DatabaseTables, error) { // revive:disable-line:flag-parameter
	dbTables := DatabaseTables{}
	var (
		schema, table, tableTypeStr string
		tableType                   TableType
		avgRowLength                uint64
		err                         error
	)

	tableTypeConditions := make([]string, len(tableTypes))
	for i, tableType := range tableTypes {
		tableTypeConditions[i] = fmt.Sprintf("TABLE_TYPE='%s'", tableType)
	}
	switch listType {
	case listTableByInfoSchema:
		query := fmt.Sprintf("SELECT TABLE_SCHEMA,TABLE_NAME,TABLE_TYPE,AVG_ROW_LENGTH FROM INFORMATION_SCHEMA.TABLES WHERE %s", strings.Join(tableTypeConditions, " OR "))
		for _, schema := range databaseNames {
			dbTables[schema] = make([]*TableInfo, 0)
		}
		if err = simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error {
			var (
				sqlAvgRowLength sql.NullInt64
				err2            error
			)
			// Check the Scan error itself so a failed scan is not treated as a valid row.
			if err2 = rows.Scan(&schema, &table, &tableTypeStr, &sqlAvgRowLength); err2 != nil {
				return errors.Trace(err2)
			}
			tableType, err2 = ParseTableType(tableTypeStr)
			if err2 != nil {
				return errors.Trace(err2)
			}
			if sqlAvgRowLength.Valid {
				avgRowLength = uint64(sqlAvgRowLength.Int64)
			} else {
				avgRowLength = 0
			}
			// only append tables to schemas in databaseNames
			if _, ok := dbTables[schema]; ok {
				dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType})
			}
			return nil
		}, query); err != nil {
			return nil, errors.Annotatef(err, "sql: %s", query)
		}
	case listTableByShowFullTables:
		for _, schema = range databaseNames {
			dbTables[schema] = make([]*TableInfo, 0)
			query := fmt.Sprintf("SHOW FULL TABLES FROM `%s` WHERE %s",
				escapeString(schema), strings.Join(tableTypeConditions, " OR "))
			if err = simpleQueryWithArgs(tctx, db, func(rows *sql.Rows) error {
				var err2 error
				// Check the Scan error itself so a failed scan is not treated as a valid row.
				if err2 = rows.Scan(&table, &tableTypeStr); err2 != nil {
					return errors.Trace(err2)
				}
				tableType, err2 = ParseTableType(tableTypeStr)
				if err2 != nil {
					return errors.Trace(err2)
				}
				avgRowLength = 0 // can't get avgRowLength from the result of `show full tables` so hardcode to 0 here
				dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType})
				return nil
			}, query); err != nil {
				return nil, errors.Annotatef(err, "sql: %s", query)
			}
		}
	default:
		const queryTemplate = "SHOW TABLE STATUS FROM `%s`"
		selectedTableType := make(map[TableType]struct{})
		for _, tableType = range tableTypes {
			selectedTableType[tableType] = struct{}{}
		}
		for _, schema = range databaseNames {
			dbTables[schema] = make([]*TableInfo, 0)
			query := fmt.Sprintf(queryTemplate, escapeString(schema))
			rows, err := db.QueryContext(tctx, query)
			if err != nil {
				return nil, errors.Annotatef(err, "sql: %s", query)
			}
			results, err := GetSpecifiedColumnValuesAndClose(rows, "NAME", "ENGINE", "AVG_ROW_LENGTH", "COMMENT")
			if err != nil {
				return nil, errors.Annotatef(err, "sql: %s", query)
			}
			for _, oneRow := range results {
				table, engine, avgRowLengthStr, comment := oneRow[0], oneRow[1], oneRow[2], oneRow[3]
				if avgRowLengthStr != "" {
					avgRowLength, err = strconv.ParseUint(avgRowLengthStr, 10, 64)
					if err != nil {
						return nil, errors.Annotatef(err, "sql: %s", query)
					}
				} else {
					avgRowLength = 0
				}
				// An entry without an engine is a view if its comment is empty or the view marker.
				tableType = TableTypeBase
				if engine == "" && (comment == "" || comment == TableTypeViewStr) {
					tableType = TableTypeView
				} else if engine == "" {
					tctx.L().Warn("invalid table without engine found", zap.String("database", schema), zap.String("table", table))
					continue
				}
				if _, ok := selectedTableType[tableType]; !ok {
					continue
				}
				dbTables[schema] = append(dbTables[schema], &TableInfo{table, avgRowLength, tableType})
			}
		}
	}
	return dbTables, nil
}
// ListAllPlacementPolicyNames returns all placement policy names.
// Names are the distinct non-NULL policy_name values of
// information_schema.placement_policies; errors are annotated with the query.
func ListAllPlacementPolicyNames(tctx *tcontext.Context, db *BaseConn) ([]string, error) {
	const query = "select distinct policy_name from information_schema.placement_policies where policy_name is not null;"
	var names []string
	collect := func(rows *sql.Rows) error {
		var name string
		if err := rows.Scan(&name); err != nil {
			return errors.Trace(err)
		}
		names = append(names, name)
		return nil
	}
	// Reset callback for QuerySQL: discard names gathered by an earlier attempt.
	reset := func() {
		names = names[:0]
	}
	err := db.QuerySQL(tctx, collect, reset, query)
	return names, errors.Annotatef(err, "sql: %s", query)
}
// SelectVersion gets the version information from the database server
// by scanning the single value returned by SELECT version().
func SelectVersion(db *sql.DB) (string, error) {
	const query = "SELECT version()"
	var versionInfo string
	if err := db.QueryRow(query).Scan(&versionInfo); err != nil {
		return "", errors.Annotatef(err, "sql: %s", query)
	}
	return versionInfo, nil
}
// SelectAllFromTable dumps data serialized from a specified table.
// It builds the SELECT for meta's table (optionally restricted to partition,
// filtered by the where condition derived from conf, ordered by
// orderByClause) and wraps it in a tableData together with the column count.
func SelectAllFromTable(conf *Config, meta TableMeta, partition, orderByClause string) TableDataIR {
	dbName, tblName := meta.DatabaseName(), meta.TableName()
	fields, fieldCount := meta.SelectedField(), meta.SelectedLen()
	where := buildWhereCondition(conf, "")
	return &tableData{
		query:  buildSelectQuery(dbName, tblName, fields, partition, where, orderByClause),
		colLen: fieldCount,
	}
}
// buildSelectQuery assembles
//
//	SELECT <fields> FROM `database`.`table`[ PARTITION(`partition`)][ <where>][ <orderByClause>]
//
// database, table and partition are escaped with escapeString; fields, where
// and orderByClause are inserted verbatim. An empty fields string becomes ''
// so the query stays valid when all of the columns are generated.
func buildSelectQuery(database, table, fields, partition, where, orderByClause string) string {
	if fields == "" {
		fields = "''"
	}
	var sb strings.Builder
	sb.WriteString("SELECT " + fields + " FROM `" + escapeString(database) + "`.`" + escapeString(table) + "`")
	if partition != "" {
		sb.WriteString(" PARTITION(`" + escapeString(partition) + "`)")
	}
	for _, clause := range [...]string{where, orderByClause} {
		if clause == "" {
			continue
		}
		sb.WriteByte(' ')
		sb.WriteString(clause)
	}
	return sb.String()
}
// buildOrderByClause returns the ORDER BY clause used when dumping a table.
// It is empty unless conf.SortByPk is set. Tables with an implicit
// _tidb_rowid are ordered by that column; otherwise the clause is built by
// buildOrderByClauseString from the table's primary key columns.
func buildOrderByClause(tctx *tcontext.Context, conf *Config, db *BaseConn, database, table string, hasImplicitRowID bool) (string, error) { // revive:disable-line:flag-parameter
	switch {
	case !conf.SortByPk:
		return "", nil
	case hasImplicitRowID:
		return orderByTiDBRowID, nil
	}
	pkCols, err := GetPrimaryKeyColumns(tctx, db, database, table)
	if err != nil {
		return "", errors.Trace(err)
	}
	return buildOrderByClauseString(pkCols), nil
}
// SelectTiDBRowID checks whether this table has _tidb_rowid column.
// It executes "SELECT _tidb_rowid ... LIMIT 1": success means the column
// exists; an error whose message contains the ErrBadField code means it does
// not (no error is returned); any other error is returned annotated with the query.
func SelectTiDBRowID(tctx *tcontext.Context, db *BaseConn, database, table string) (bool, error) {
	query := fmt.Sprintf("SELECT _tidb_rowid from `%s`.`%s` LIMIT 1", escapeString(database), escapeString(table))
	badFieldCode := fmt.Sprintf("%d", errno.ErrBadField)
	var hasRowID bool
	err := db.ExecSQL(tctx, func(_ sql.Result, execErr error) error {
		hasRowID = execErr == nil
		if execErr == nil {
			return nil
		}
		if strings.Contains(strings.ToLower(execErr.Error()), badFieldCode) {
			return nil
		}
		return errors.Annotatef(execErr, "sql: %s", query)
	}, query)
	return hasRowID, err
}
// GetSuitableRows gets suitable rows for each table.
// It estimates how many rows of avgRowLength bytes fit into a 128MB file,
// capped at 1,000,000. When avgRowLength is unknown (0) it returns 200,000.
func GetSuitableRows(avgRowLength uint64) uint64 {
	const (
		defaultRows  = 200000
		maxRows      = 1000000
		bytesPerFile = 128 * 1024 * 1024 // 128MB per file by default
	)
	if avgRowLength == 0 {
		return defaultRows
	}
	if rows := bytesPerFile / avgRowLength; rows <= maxRows {
		return rows
	}
	return maxRows
}
// GetColumnTypes gets *sql.ColumnTypes from a specified table.
// It runs "SELECT <fields> FROM `database`.`table` LIMIT 1" and returns the
// column metadata of that result. fields is inserted into the query verbatim
// (NOTE(review): presumably an already-escaped field list such as the one
// produced by buildSelectField — confirm at call sites).
func GetColumnTypes(tctx *tcontext.Context, db *BaseConn, fields, database, table string) ([]*sql.ColumnType, error) {
	query := fmt.Sprintf("SELECT %s FROM `%s`.`%s` LIMIT 1", fields, escapeString(database), escapeString(table))
	var colTypes []*sql.ColumnType
	err := db.QuerySQL(tctx, func(rows *sql.Rows) error {
		var err error
		colTypes, err = rows.ColumnTypes()
		// Only the metadata is needed, so the result set is closed right away.
		if err == nil {
			err = rows.Close()
		}
		// Test hook: lets tests simulate a broken connection while reading metadata.
		failpoint.Inject("ChaosBrokenMetaConn", func(_ failpoint.Value) {
			failpoint.Return(errors.New("connection is closed"))
		})
		// Annotatef passes a nil err through as nil.
		return errors.Annotatef(err, "sql: %s", query)
	}, func() {
		// Reset callback for QuerySQL: drop metadata from an earlier attempt.
		colTypes = nil
	}, query)
	if err != nil {
		return nil, err
	}
	return colTypes, nil
}
// GetPrimaryKeyAndColumnTypes gets all primary columns and their types in ordinal order.
// Column names come from GetPrimaryKeyColumns; each type is looked up by name in
// the column names/types recorded in meta.
func GetPrimaryKeyAndColumnTypes(tctx *tcontext.Context, conn *BaseConn, meta TableMeta) ([]string, []string, error) {
	pkNames, err := GetPrimaryKeyColumns(tctx, conn, meta.DatabaseName(), meta.TableName())
	if err != nil {
		return nil, nil, err
	}
	typeByName := string2Map(meta.ColumnNames(), meta.ColumnTypes())
	pkTypes := make([]string, 0, len(pkNames))
	for _, name := range pkNames {
		pkTypes = append(pkTypes, typeByName[name])
	}
	return pkNames, pkTypes, nil
}
// GetPrimaryKeyColumns gets all primary columns in ordinal order.
// It keeps the COLUMN_NAME of every SHOW INDEX row whose KEY_NAME is PRIMARY,
// in the order the rows are returned; the result is empty (not nil) when the
// table has no primary key.
func GetPrimaryKeyColumns(tctx *tcontext.Context, db *BaseConn, database, table string) ([]string, error) {
	query := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", escapeString(database), escapeString(table))
	indexRows, err := db.QuerySQLWithColumns(tctx, []string{"KEY_NAME", "COLUMN_NAME"}, query)
	if err != nil {
		return nil, err
	}
	pkCols := make([]string, 0, len(indexRows))
	for i := range indexRows {
		if indexRows[i][0] != "PRIMARY" {
			continue
		}
		pkCols = append(pkCols, indexRows[i][1])
	}
	return pkCols, nil
}
// getNumericIndex picks up indices according to the following priority:
// primary key > unique key with the smallest count > key with the max cardinality
// primary key with multi cols is before unique key with single col because we will sort result by primary keys
//
// Only indexes whose first column has a type listed in dataTypeInt are
// considered, and only that first column's name is returned. Among unique keys
// with the same column count, the one listed first by SHOW INDEX wins, so the
// choice is deterministic. An empty string means no suitable index was found.
func getNumericIndex(tctx *tcontext.Context, db *BaseConn, meta TableMeta) (string, error) {
	database, table := meta.DatabaseName(), meta.TableName()
	colName2Type := string2Map(meta.ColumnNames(), meta.ColumnTypes())
	keyQuery := fmt.Sprintf("SHOW INDEX FROM `%s`.`%s`", escapeString(database), escapeString(table))
	results, err := db.QuerySQLWithColumns(tctx, []string{"NON_UNIQUE", "SEQ_IN_INDEX", "KEY_NAME", "COLUMN_NAME", "CARDINALITY"}, keyQuery)
	if err != nil {
		return "", err
	}
	type keyColumnPair struct {
		colName string
		count   uint64
	}
	var (
		uniqueKeyMap   = map[string]keyColumnPair{} // unique key name -> key column name, unique key columns count
		uniqueKeyNames []string                     // unique key names in the order SHOW INDEX listed them
		keyColumn      string
		maxCardinality int64 = -1
	)
	// check primary key first, then unique key
	for _, oneRow := range results {
		nonUnique, seqInIndex, keyName, colName, cardinality := oneRow[0], oneRow[1], oneRow[2], oneRow[3], oneRow[4]
		// only try pick the first column, because the second column of pk/uk in where condition will trigger a full table scan
		if seqInIndex != "1" {
			// Later columns of an already-recorded unique key only raise its column count.
			if pair, ok := uniqueKeyMap[keyName]; ok {
				seqInIndexInt, err := strconv.ParseUint(seqInIndex, 10, 64)
				if err == nil && seqInIndexInt > pair.count {
					uniqueKeyMap[keyName] = keyColumnPair{pair.colName, seqInIndexInt}
				}
			}
			continue
		}
		_, numberColumn := dataTypeInt[colName2Type[colName]]
		if numberColumn {
			switch {
			case keyName == "PRIMARY":
				return colName, nil
			case nonUnique == "0":
				if _, seen := uniqueKeyMap[keyName]; !seen {
					uniqueKeyNames = append(uniqueKeyNames, keyName)
				}
				uniqueKeyMap[keyName] = keyColumnPair{colName, 1}
			// pick index column with max cardinality when there is no unique index
			case len(uniqueKeyMap) == 0:
				cardinalityInt, err := strconv.ParseInt(cardinality, 10, 64)
				if err == nil && cardinalityInt > maxCardinality {
					keyColumn = colName
					maxCardinality = cardinalityInt
				}
			}
		}
	}
	if len(uniqueKeyMap) > 0 {
		var (
			minCols         uint64 = math.MaxUint64
			uniqueKeyColumn string
		)
		// Walk keys in SHOW INDEX order rather than map order so that ties
		// resolve to the same index on every run.
		for _, keyName := range uniqueKeyNames {
			pair := uniqueKeyMap[keyName]
			if pair.count < minCols {
				uniqueKeyColumn = pair.colName
				minCols = pair.count
			}
		}
		return uniqueKeyColumn, nil
	}
	return keyColumn, nil
}
// FlushTableWithReadLock flush tables with read lock.
// It issues FLUSH TABLES WITH READ LOCK on the given connection; UnlockTables
// on the same connection releases it.
func FlushTableWithReadLock(ctx context.Context, db *sql.Conn) error {
	const ftwrlQuery = "FLUSH TABLES WITH READ LOCK"
	if _, err := db.ExecContext(ctx, ftwrlQuery); err != nil {
		return errors.Annotatef(err, "sql: %s", ftwrlQuery)
	}
	return nil
}
// LockTables locks table with read lock.
// It runs LOCK TABLES `database`.`table` READ on the given connection.
func LockTables(ctx context.Context, db *sql.Conn, database, table string) error {
	stmt := fmt.Sprintf("LOCK TABLES `%s`.`%s` READ", escapeString(database), escapeString(table))
	if _, err := db.ExecContext(ctx, stmt); err != nil {
		return errors.Annotatef(err, "sql: %s", stmt)
	}
	return nil
}
// UnlockTables unlocks all tables' lock held by the given connection
// by running UNLOCK TABLES.
func UnlockTables(ctx context.Context, db *sql.Conn) error {
	const unlockTableQuery = "UNLOCK TABLES"
	if _, err := db.ExecContext(ctx, unlockTableQuery); err != nil {
		return errors.Annotatef(err, "sql: %s", unlockTableQuery)
	}
	return nil
}
// ShowMasterStatus get SHOW MASTER STATUS result from database.
// Every column of a result row is scanned as a string. The row most recently
// passed to the row handler is returned; nil is returned if none was handled.
func ShowMasterStatus(db *sql.Conn) ([]string, error) {
	const showMasterStatusQuery = "SHOW MASTER STATUS"
	var lastRow []string
	scanAllColumns := func(rows *sql.Rows) error {
		cols, err := rows.Columns()
		if err != nil {
			return errors.Trace(err)
		}
		row := make([]string, len(cols))
		dest := make([]interface{}, len(cols))
		for i := range row {
			dest[i] = &row[i]
		}
		lastRow = row
		return rows.Scan(dest...)
	}
	if err := simpleQuery(db, showMasterStatusQuery, scanAllColumns); err != nil {
		return nil, errors.Annotatef(err, "sql: %s", showMasterStatusQuery)
	}
	return lastRow, nil
}
// GetSpecifiedColumnValueAndClose get columns' values whose name is equal to columnName and close the given rows.
// Column names are compared after upper-casing both sides; NULL values are
// skipped, and if several columns share the name the last one is used.
// A nil rows yields an empty, non-nil slice. An error from rows.Columns is
// now returned instead of being ignored.
func GetSpecifiedColumnValueAndClose(rows *sql.Rows, columnName string) ([]string, error) {
	if rows == nil {
		return []string{}, nil
	}
	// The Close error is not actionable here; scan/iteration errors are reported below.
	defer rows.Close()
	columnName = strings.ToUpper(columnName)
	var strs []string
	columns, err := rows.Columns()
	if err != nil {
		return strs, errors.Trace(err)
	}
	addr := make([]interface{}, len(columns))
	oneRow := make([]sql.NullString, len(columns))
	fieldIndex := -1
	for i, col := range columns {
		if strings.ToUpper(col) == columnName {
			fieldIndex = i
		}
		addr[i] = &oneRow[i]
	}
	if fieldIndex == -1 {
		return strs, nil
	}
	for rows.Next() {
		if err := rows.Scan(addr...); err != nil {
			return strs, errors.Trace(err)
		}
		if oneRow[fieldIndex].Valid {
			strs = append(strs, oneRow[fieldIndex].String)
		}
	}
	return strs, errors.Trace(rows.Err())
}
// GetSpecifiedColumnValuesAndClose get columns' values whose name is equal to columnName
// and close the given rows.
// Each returned row holds the values for columnName in the order the names were
// given. Names are matched case-insensitively (both sides upper-cased), like
// GetSpecifiedColumnValueAndClose. A requested column that is absent or NULL
// yields "", and rows where every requested column is NULL are dropped.
// A nil rows yields an empty, non-nil slice.
func GetSpecifiedColumnValuesAndClose(rows *sql.Rows, columnName ...string) ([][]string, error) {
	if rows == nil {
		return [][]string{}, nil
	}
	// The Close error is not actionable here; scan/iteration errors are reported below.
	defer rows.Close()
	var strs [][]string
	columns, err := rows.Columns()
	if err != nil {
		return strs, errors.Trace(err)
	}
	// Upper-case the requested names once so callers need not pass them upper-cased.
	wanted := make([]string, len(columnName))
	for j, name := range columnName {
		wanted[j] = strings.ToUpper(name)
	}
	addr := make([]interface{}, len(columns))
	oneRow := make([]sql.NullString, len(columns))
	fieldIndexMp := make(map[int]int) // result column position -> position in columnName
	for i, col := range columns {
		addr[i] = &oneRow[i]
		upperCol := strings.ToUpper(col)
		for j, name := range wanted {
			if upperCol == name {
				fieldIndexMp[i] = j
			}
		}
	}
	if len(fieldIndexMp) == 0 {
		return strs, nil
	}
	for rows.Next() {
		if err := rows.Scan(addr...); err != nil {
			return strs, errors.Trace(err)
		}
		written := false
		tmpStr := make([]string, len(columnName))
		for colPos, namePos := range fieldIndexMp {
			if oneRow[colPos].Valid {
				written = true
				tmpStr[namePos] = oneRow[colPos].String
			}
		}
		if written {
			strs = append(strs, tmpStr)
		}
	}
	return strs, errors.Trace(rows.Err())
}
// GetPdAddrs gets PD address from TiDB.
// It returns the STATUS_ADDRESS column of the information_schema.cluster_info
// rows whose type is 'pd'; errors are annotated with the query.
func GetPdAddrs(tctx *tcontext.Context, db *sql.DB) ([]string, error) {
	const query = "SELECT * FROM information_schema.cluster_info where type = 'pd';"
	var (
		rows *sql.Rows
		err  error
	)
	if rows, err = db.QueryContext(tctx, query); err != nil {
		return []string{}, errors.Annotatef(err, "sql: %s", query)
	}
	pdAddrs, scanErr := GetSpecifiedColumnValueAndClose(rows, "STATUS_ADDRESS")
	return pdAddrs, errors.Annotatef(scanErr, "sql: %s", query)
}
// GetTiDBDDLIDs gets DDL IDs from TiDB.
// It returns the DDL_ID column of information_schema.tidb_servers_info;
// errors are annotated with the query.
func GetTiDBDDLIDs(tctx *tcontext.Context, db *sql.DB) ([]string, error) {
	const query = "SELECT * FROM information_schema.tidb_servers_info;"
	var (
		rows *sql.Rows
		err  error
	)
	if rows, err = db.QueryContext(tctx, query); err != nil {
		return []string{}, errors.Annotatef(err, "sql: %s", query)
	}
	ddlIDs, scanErr := GetSpecifiedColumnValueAndClose(rows, "DDL_ID")
	return ddlIDs, errors.Annotatef(scanErr, "sql: %s", query)
}
// getTiDBConfig gets tidb config from TiDB server by JSON-decoding the value
// of SELECT @@tidb_config.
// @@tidb_config details doc https://docs.pingcap.com/tidb/stable/system-variables#tidb_config
// this variable exists at least from v2.0.0, so this works in most existing tidb instances
func getTiDBConfig(db *sql.Conn) (dbconfig.Config, error) {
	const query = "SELECT @@tidb_config;"
	var (
		cfg dbconfig.Config
		raw []byte
	)
	if err := db.QueryRowContext(context.Background(), query).Scan(&raw); err != nil {
		return cfg, errors.Annotatef(err, "sql: %s", query)
	}
	decodeErr := json.Unmarshal(raw, &cfg)
	return cfg, errors.Annotatef(decodeErr, "sql: %s", query)
}
// CheckTiDBWithTiKV use sql to check whether current TiDB has TiKV.
// It first reads @@tidb_config on a dedicated connection and compares Store
// with "tikv". If that fails, it counts the tikv_gc_safe_point row in
// mysql.tidb instead; when that query fails too it reports true together with
// the error, because most production TiDB clusters run on TiKV.
func CheckTiDBWithTiKV(db *sql.DB) (bool, error) {
	if conn, connErr := db.Conn(context.Background()); connErr == nil {
		// Deferred to function exit, so conn stays checked out during the fallback query.
		defer func() {
			_ = conn.Close()
		}()
		if cfg, cfgErr := getTiDBConfig(conn); cfgErr == nil {
			return cfg.Store == "tikv", nil
		}
	}
	const query = "SELECT COUNT(1) as c FROM MYSQL.TiDB WHERE VARIABLE_NAME='tikv_gc_safe_point'"
	var count int
	if err := db.QueryRow(query).Scan(&count); err != nil {
		// still return true here. Because sometimes users may not have privileges for MySQL.TiDB database
		// In most production cases TiDB has TiKV
		return true, errors.Annotatef(err, "sql: %s", query)
	}
	return count > 0, nil
}
// CheckIfSeqExists use sql to check whether sequence exists.
// It reports whether INFORMATION_SCHEMA.TABLES lists at least one table of
// type SEQUENCE.
func CheckIfSeqExists(db *sql.Conn) (bool, error) {
	const query = "SELECT COUNT(1) as c FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_TYPE='SEQUENCE'"
	var seqCount int
	if err := db.QueryRowContext(context.Background(), query).Scan(&seqCount); err != nil {
		return false, errors.Annotatef(err, "sql: %s", query)
	}
	return seqCount > 0, nil
}
// CheckTiDBEnableTableLock uses the @@tidb_config system variable to check
// whether the connected TiDB server has table locks enabled (the
// EnableTableLock field of its config). Any error reading the config is
// returned together with false.
func CheckTiDBEnableTableLock(db *sql.Conn) (bool, error) {
	tidbConfig, err := getTiDBConfig(db)
	if err != nil {
		return false, err
	}
	return tidbConfig.EnableTableLock, nil
}
// getSnapshot returns the column at snapshotFieldIndex of the SHOW MASTER STATUS
// row. If the statement yields no row, or a row with too few columns, an error is
// returned instead of panicking on the index. (NOTE(review): an empty result
// presumably occurs e.g. on MySQL with binary logging disabled — confirm.)
func getSnapshot(db *sql.Conn) (string, error) {
	str, err := ShowMasterStatus(db)
	if err != nil {
		return "", err
	}
	if len(str) <= snapshotFieldIndex {
		return "", errors.Errorf("unexpected SHOW MASTER STATUS result: got %d columns, need at least %d",
			len(str), snapshotFieldIndex+1)
	}
	return str[snapshotFieldIndex], nil
}
// isUnknownSystemVariableErr reports whether err's message contains the
// case-sensitive text "Unknown system variable". err must be non-nil.
func isUnknownSystemVariableErr(err error) bool {
	const marker = "Unknown system variable"
	msg := err.Error()
	return strings.Contains(msg, marker)
}
// resetDBWithSessionParams will return a new sql.DB as a replacement for input `db` with new session parameters.
//
// Each params entry is first tried with "SET SESSION k = ?" on db. String values
// that parse as int64 (or else float64) are sent as numbers. Variables the
// server reports as unknown are skipped and logged. Any other error aborts
// before db is touched, leaving it open.
//
// Accepted variables are appended to dsn as "&k=v", with string values
// single-quoted and every value URL-escaped. db is then closed and a new pool
// is opened from the extended dsn and pinged. If opening or pinging fails,
// nil is returned with the error, and the input db has already been closed.
//
// NOTE(review): parameters are appended with '&', so dsn is assumed to already
// carry a query string — confirm at call sites.
func resetDBWithSessionParams(tctx *tcontext.Context, db *sql.DB, dsn string, params map[string]interface{}) (*sql.DB, error) {
	support := make(map[string]interface{})
	for k, v := range params {
		var pv interface{}
		if str, ok := v.(string); ok {
			if pvi, err := strconv.ParseInt(str, 10, 64); err == nil {
				pv = pvi
			} else if pvf, err := strconv.ParseFloat(str, 64); err == nil {
				pv = pvf
			} else {
				pv = str
			}
		} else {
			pv = v
		}
		// NOTE(review): k is spliced into the statement unquoted (a variable name
		// cannot be a placeholder); it must come from trusted configuration.
		s := fmt.Sprintf("SET SESSION %s = ?", k)
		_, err := db.ExecContext(tctx, s, pv)
		if err != nil {
			if isUnknownSystemVariableErr(err) {
				tctx.L().Info("session variable is not supported by db", zap.String("variable", k), zap.Reflect("value", v))
				continue
			}
			return nil, errors.Trace(err)
		}
		support[k] = pv
	}
	for k, v := range support {
		var s string
		// Wrap string with quote to handle string with space. For example, '2020-10-20 13:41:40'
		// For --params argument, quote doesn't matter because it doesn't affect the actual value
		if str, ok := v.(string); ok {
			s = wrapStringWith(str, "'")
		} else {
			s = fmt.Sprintf("%v", v)
		}
		dsn += fmt.Sprintf("&%s=%s", k, url.QueryEscape(s))
	}
	// The old pool is being replaced; its Close error is not actionable here.
	_ = db.Close()
	newDB, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, errors.Trace(err)
	}
	// ping to make sure all session parameters are set correctly
	if err = newDB.PingContext(tctx); err != nil {
		_ = newDB.Close()
		// Never hand back an already-closed pool.
		return nil, errors.Trace(err)
	}
	return newDB, nil
}
// createConnWithConsistency takes a dedicated connection from db and starts a
// transaction on it. When repeatableRead is true the session isolation level
// is first set to REPEATABLE READ. It tries
// "START TRANSACTION /*!40108 WITH CONSISTENT SNAPSHOT */" and falls back to a
// plain START TRANSACTION if that statement fails.
// On any failure after the connection was obtained, the connection is closed
// (handed back to database/sql) instead of being leaked, and nil is returned.
func createConnWithConsistency(ctx context.Context, db *sql.DB, repeatableRead bool) (*sql.Conn, error) {
	conn, err := db.Conn(ctx)
	if err != nil {
		return nil, errors.Trace(err)
	}
	// fail releases conn and annotates err with the statement that failed.
	fail := func(err error, query string) (*sql.Conn, error) {
		_ = conn.Close()
		return nil, errors.Annotatef(err, "sql: %s", query)
	}
	var query string
	if repeatableRead {
		query = "SET SESSION TRANSACTION ISOLATION LEVEL REPEATABLE READ"
		if _, err = conn.ExecContext(ctx, query); err != nil {
			return fail(err, query)
		}
	}
	query = "START TRANSACTION /*!40108 WITH CONSISTENT SNAPSHOT */"
	if _, err = conn.ExecContext(ctx, query); err != nil {
		// Some MySQL Compatible databases like Vitess and MemSQL/SingleStore
		// are newer than 4.1.8 (the version comment) but don't actually support
		// `WITH CONSISTENT SNAPSHOT`. So retry without that if the statement fails.
		query = "START TRANSACTION"
		if _, err = conn.ExecContext(ctx, query); err != nil {
			return fail(err, query)
		}
	}
	return conn, nil
}
// buildSelectField returns the selecting fields' string(joined by comma(`,`)),
// and the number of writable fields.
// Generated columns (EXTRA "STORED GENERATED" / "VIRTUAL GENERATED") are never
// written. The explicit back-quoted column list is returned when completeInsert
// is set or the table has a generated column; otherwise "*" is returned.
func buildSelectField(tctx *tcontext.Context, db *BaseConn, dbName, tableName string, completeInsert bool) (string, int, error) { // revive:disable-line:flag-parameter
	query := fmt.Sprintf("SHOW COLUMNS FROM `%s`.`%s`", escapeString(dbName), escapeString(tableName))
	columnRows, err := db.QuerySQLWithColumns(tctx, []string{"FIELD", "EXTRA"}, query)
	if err != nil {
		return "", 0, err
	}
	var (
		writable     = make([]string, 0)
		hasGenerated bool
	)
	for _, col := range columnRows {
		if extra := col[1]; extra == "STORED GENERATED" || extra == "VIRTUAL GENERATED" {
			hasGenerated = true
			continue
		}
		writable = append(writable, wrapBackTicks(escapeString(col[0])))
	}
	if !completeInsert && !hasGenerated {
		return "*", len(writable), nil
	}
	return strings.Join(writable, ","), len(writable), nil
}
// buildWhereClauses turns the split points handleVals over the columns
// handleColNames into len(handleVals)+1 WHERE conditions (without the WHERE
// keyword) covering (-inf, v0), [v0, v1), ..., [vLast, +inf).
// NOTE(review): handleVals presumably arrive in ascending order — confirm at
// the caller. It returns nil when there are no columns or no split points.
func buildWhereClauses(handleColNames []string, handleVals [][]string) []string {
	if len(handleColNames) == 0 || len(handleVals) == 0 {
		return nil
	}
	quotedCols := make([]string, 0, len(handleColNames))
	for _, name := range handleColNames {
		quotedCols = append(quotedCols, fmt.Sprintf("`%s`", escapeString(name)))
	}
	clauses := make([]string, 0, len(handleVals)+1)
	var buf bytes.Buffer
	// take moves the clause currently in buf into clauses.
	take := func() {
		clauses = append(clauses, buf.String())
		buf.Reset()
	}
	// (-inf, first)
	buildCompareClause(&buf, quotedCols, handleVals[0], less, false)
	take()
	// [previous, current) for every adjacent pair of split points
	for i, up := range handleVals[1:] {
		buildBetweenClause(&buf, quotedCols, handleVals[i], up)
		take()
	}
	// [last, +inf)
	buildCompareClause(&buf, quotedCols, handleVals[len(handleVals)-1], greater, true)
	take()
	return clauses
}
// Comparison operator bytes written by buildCompareClause and
// buildBetweenClause when building range WHERE clauses.
const (
	// equal is the equality operator.
	equal = '='
	// less is the strict less-than operator.
	less = '<'
	// greater is the strict greater-than operator.
	greater = '>'
)
// buildCompareClause build clause with specified bounds. Usually we will use the following two conditions:
// (compare, writeEqual) == (less, false), return quotaCols < bound clause. In other words, (-inf, bound)
// (compare, writeEqual) == (greater, true), return quotaCols >= bound clause. In other words, [bound, +inf)
//
// The row-wise comparison is expanded into a disjunction, e.g. for columns a,b:
//
//	a<1 or(a=1 and b<2)
//
// Only the last column's comparison gets the optional '='. The result has no
// outer brackets.
func buildCompareClause(buf *bytes.Buffer, quotaCols []string, bound []string, compare byte, writeEqual bool) { // revive:disable-line:flag-parameter
	last := len(quotaCols) - 1
	for idx := range quotaCols {
		nested := idx > 0
		if nested {
			buf.WriteString("or(")
		}
		// Every earlier column must equal its bound.
		for prev := 0; prev < idx; prev++ {
			buf.WriteString(quotaCols[prev])
			buf.WriteByte(equal)
			buf.WriteString(bound[prev])
			buf.WriteString(" and ")
		}
		buf.WriteString(quotaCols[idx])
		buf.WriteByte(compare)
		if writeEqual && idx == last {
			buf.WriteByte(equal)
		}
		buf.WriteString(bound[idx])
		switch {
		case nested:
			buf.WriteByte(')')
		case idx != last:
			buf.WriteByte(' ')
		}
	}
}
// getCommonLength returns the length of the longest common prefix of low and
// up, i.e. the number of leading positions at which both tuples hold the same
// value. Comparison stops at the end of the shorter slice, so tuples of
// different lengths never cause an out-of-range panic.
func getCommonLength(low []string, up []string) int {
	n := len(low)
	if len(up) < n {
		n = len(up)
	}
	for i := 0; i < n; i++ {
		if low[i] != up[i] {
			return i
		}
	}
	return n
}
// buildBetweenClause build clause in a specified table range.
// the result where clause will be low <= quotaCols < up. In other words, [low, up)
func buildBetweenClause(buf *bytes.Buffer, quotaCols []string, low []string, up []string) {
singleBetween := func(writeEqual bool) {
buf.WriteString(quotaCols[0])
buf.WriteByte(greater)
if writeEqual {
buf.WriteByte(equal)
}
buf.WriteString(low[0])
buf.WriteString(" and ")
buf.WriteString(quotaCols[0])
buf.WriteByte(less)
buf.WriteString(up[0])
}
// handle special cases with common prefix
commonLen := getCommonLength(low, up)
if commonLen > 0 {
// unexpected case for low == up, return empty result
if commonLen == len(low) {
buf.WriteString("false")
return
}
for i := 0; i < commonLen; i++ {
if i > 0 {
buf.WriteString(" and ")
}
buf.WriteString(quotaCols[i])
buf.WriteByte(equal)
buf.WriteString(low[i])
}
buf.WriteString(" and(")
defer buf.WriteByte(')')
quotaCols = quotaCols[commonLen:]
low = low[commonLen:]
up = up[commonLen:]
}
// handle special cases with only one column
if len(quotaCols) == 1 {
singleBetween(true)
return
}
buf.WriteByte('(')
singleBetween(false)
buf.WriteString(")or(")
buf.WriteString(quotaCols[0])
buf.WriteByte(equal)
buf.WriteString(low[0])
buf.WriteString(" and(")
buildCompareClause(buf, quotaCols[1:], low[1:], greater, true)
buf.WriteString("))or(")
buf.WriteString(quotaCols[0])
buf.WriteByte(equal)
buf.WriteString(up[0])
buf.WriteString(" and(")
buildCompareClause(buf, quotaCols[1:], up[1:], less, false)
buf.WriteString("))")
}
func buildOrderByClauseString(handleColNames []string) string {
if len(handleColNames) == 0 {
return ""
}
separator := ","
quotaCols := make([]string, len(handleColNames))
for i, col := range handleColNames {
quotaCols[i] = fmt.Sprintf("`%s`", escapeString(col))
}
return fmt.Sprintf("ORDER BY %s", strings.Join(quotaCols, separator))
}
// buildLockTablesSQL builds one "LOCK TABLES `db`.`t1` READ,`db`.`t2` READ,..."
// statement covering every base table in allTables that is not listed in
// blockList (keyed by database name, then table name). It returns "" when no
// table qualifies.
//
// Only base tables are locked. Locking a view would lock its related tables,
// but views are dumped as CREATE statements only, so they need no lock;
// mydumper likewise locks only base tables:
// https://github.com/maxbube/mydumper/blob/1fabdf87e3007e5934227b504ad673ba3697946c/mydumper.c#L1568
//
// allTables is ranged over as a map, so the order of tables in the statement
// may differ between calls.
func buildLockTablesSQL(allTables DatabaseTables, blockList map[string]map[string]interface{}) string {
	// Size the buffer from the actual tables: the first entry costs
	// len("LOCK TABLES ")+len("``.`` READ") = 12+10 bytes and every further one
	// len(",``.`` READ") = 11 bytes, plus the database and table names.
	// Skipped tables are counted too, so this is an upper bound (before escaping).
	estimatedCap := len("LOCK TABLES")
	for dbName, tables := range allTables {
		estimatedCap += len(tables) * (11 + len(dbName))
		for _, table := range tables {
			estimatedCap += len(table.Name)
		}
	}

	var s strings.Builder
	s.Grow(estimatedCap)
	wroteAny := false
	for dbName, tables := range allTables {
		escapedDBName := escapeString(dbName)
		for _, table := range tables {
			if table.Type != TableTypeBase {
				continue
			}
			// Indexing a missing inner map yields the zero value, so this is safe
			// even when dbName has no block entries.
			if _, blocked := blockList[dbName][table.Name]; blocked {
				continue
			}
			if wroteAny {
				fmt.Fprintf(&s, ",`%s`.`%s` READ", escapedDBName, escapeString(table.Name))
			} else {
				fmt.Fprintf(&s, "LOCK TABLES `%s`.`%s` READ", escapedDBName, escapeString(table.Name))
				wroteAny = true
			}
		}
	}
	return s.String()
}
// oneStrColumnTable accumulates the single string column of each row it is fed.
type oneStrColumnTable struct {
	data []string
}

// handleOneRow scans the current row's only column and appends it to o.data.
// The scan error, if any, is returned traced and nothing is appended.
func (o *oneStrColumnTable) handleOneRow(rows *sql.Rows) error {
	var value string
	err := rows.Scan(&value)
	if err == nil {
		o.data = append(o.data, value)
		return nil
	}
	return errors.Trace(err)
}
// simpleQuery runs query without arguments on conn, using a background context,
// and calls handleOneRow for every returned row.
func simpleQuery(conn *sql.Conn, query string, handleOneRow func(*sql.Rows) error) error {
	ctx := context.Background()
	return simpleQueryWithArgs(ctx, conn, handleOneRow, query)
}
// simpleQueryWithArgs runs query with args on conn and calls handleOneRow for
// each result row, stopping at the first error handleOneRow returns.
//
// Errors from the query, from handleOneRow, or from row iteration are annotated
// with the query text and its arguments; nil is returned on success.
func simpleQueryWithArgs(ctx context.Context, conn *sql.Conn, handleOneRow func(*sql.Rows) error, query string, args ...interface{}) error {
	// QueryContext treats an empty variadic list exactly like no arguments, so one
	// call covers both cases.
	rows, err := conn.QueryContext(ctx, query, args...)
	if err != nil {
		return errors.Annotatef(err, "sql: %s, args: %v", query, args)
	}
	// The deferred Close also covers the early return from the loop below.
	defer rows.Close()
	for rows.Next() {
		if err := handleOneRow(rows); err != nil {
			return errors.Annotatef(err, "sql: %s, args: %v", query, args)
		}
	}
	// Annotatef passes a nil error through as nil, so this is nil on success.
	return errors.Annotatef(rows.Err(), "sql: %s, args: %v", query, args)
}
// pickupPossibleField picks a column of meta's table to work with: the implicit
// "_tidb_rowid" when the table has one, otherwise the primary/unique key column
// chosen by getNumericIndex. An empty name with a nil error means no proper
// index was found; on error the name is always "".
func pickupPossibleField(tctx *tcontext.Context, meta TableMeta, db *BaseConn) (string, error) {
	var (
		field string
		err   error
	)
	switch {
	case meta.HasImplicitRowID():
		field = "_tidb_rowid"
	default:
		field, err = getNumericIndex(tctx, db, meta)
		if err != nil {
			field = ""
		}
	}
	return field, err
}
// estimateCount returns the optimizer's row estimate for dbName.tableName. It
// runs EXPLAIN on a SELECT of field (or of "*" when field is blank or "*"),
// appending conf.Where as a WHERE clause when it is set. The result is whatever
// detectEstimateRows reports, which is 0 when no estimate can be obtained.
func estimateCount(tctx *tcontext.Context, dbName, tableName string, db *BaseConn, field string, conf *Config) uint64 {
	projection := "*"
	if trimmed := strings.TrimSpace(field); trimmed != "*" && trimmed != "" {
		projection = fmt.Sprintf("`%s`", escapeString(field))
	}
	query := fmt.Sprintf("EXPLAIN SELECT %s FROM `%s`.`%s`", projection, escapeString(dbName), escapeString(tableName))
	if conf.Where != "" {
		query += " WHERE " + conf.Where
	}
	/* tidb results field name is estRows (before 4.0.0-beta.2: count)
	+-----------------------+----------+-----------+---------------------------------------------------------+
	| id | estRows | task | access object | operator info |
	+-----------------------+----------+-----------+---------------------------------------------------------+
	| tablereader_5 | 10000.00 | root | | data:tablefullscan_4 |
	| └─tablefullscan_4 | 10000.00 | cop[tikv] | table:a | table:a, keep order:false, stats:pseudo |
	+-----------------------+----------+-----------+----------------------------------------------------------
	mariadb result field name is rows
	+------+-------------+---------+-------+---------------+------+---------+------+----------+-------------+
	| id | select_type | table | type | possible_keys | key | key_len | ref | rows | Extra |
	+------+-------------+---------+-------+---------------+------+---------+------+----------+-------------+
	| 1 | SIMPLE | sbtest1 | index | NULL | k_1 | 4 | NULL | 15000049 | Using index |
	+------+-------------+---------+-------+---------------+------+---------+------+----------+-------------+
	mysql result field name is rows
	+----+-------------+-------+------------+-------+---------------+-----------+---------+------+------+----------+-------------+
	| id | select_type | table | partitions | type | possible_keys | key | key_len | ref | rows | filtered | Extra |
	+----+-------------+-------+------------+-------+---------------+-----------+---------+------+------+----------+-------------+
	| 1 | SIMPLE | t1 | NULL | index | NULL | multi_col | 10 | NULL | 5 | 100.00 | Using index |
	+----+-------------+-------+------------+-------+---------------+-----------+---------+------+------+----------+-------------+
	*/
	// The estimate is unsigned, so it can be returned without clamping.
	return detectEstimateRows(tctx, db, query, []string{"rows", "estRows", "count"})
}
// detectEstimateRows runs query (an EXPLAIN statement) and returns the row
// estimate found in the first result column whose name matches one of
// fieldNames case-insensitively. Fractional estimates are truncated.
//
// It returns 0, after logging, when the estimate cannot be obtained: the query
// fails, yields no rows, has no matching column, or the value is not a
// non-negative number.
func detectEstimateRows(tctx *tcontext.Context, db *BaseConn, query string, fieldNames []string) uint64 {
	var (
		// fieldIndex starts at -1 so that a result set with no rows is reported as
		// "not found" instead of indexing into a nil oneRow.
		fieldIndex = -1
		oneRow     []sql.NullString
	)
	err := db.QuerySQL(tctx, func(rows *sql.Rows) error {
		columns, err := rows.Columns()
		if err != nil {
			return errors.Trace(err)
		}
		addr := make([]interface{}, len(columns))
		oneRow = make([]sql.NullString, len(columns))
		fieldIndex = -1
	found:
		for i := range oneRow {
			for _, fieldName := range fieldNames {
				if strings.EqualFold(columns[i], fieldName) {
					fieldIndex = i
					break found
				}
			}
		}
		if fieldIndex == -1 {
			// No estimate column: close early so no further rows are read.
			rows.Close()
			return nil
		}
		for i := range oneRow {
			addr[i] = &oneRow[i]
		}
		// NOTE(review): if QuerySQL invokes this callback once per row, oneRow ends
		// up holding the last scanned row — confirm against BaseConn.QuerySQL.
		return rows.Scan(addr...)
	}, func() {
		// Reset state so that a retried query never reports stale values.
		fieldIndex = -1
		oneRow = nil
	}, query)
	if err != nil || fieldIndex == -1 {
		tctx.L().Info("can't estimate rows from db",
			zap.String("query", query), zap.Int("fieldIndex", fieldIndex), log.ShortError(err))
		return 0
	}
	estRows, err := strconv.ParseFloat(oneRow[fieldIndex].String, 64)
	if err != nil {
		tctx.L().Info("can't get parse estimate rows from db",
			zap.String("query", query), zap.String("estRows", oneRow[fieldIndex].String), log.ShortError(err))
		return 0
	}
	// Converting NaN or a negative float to uint64 is implementation-dependent
	// in Go, so treat such values as "no estimate".
	if math.IsNaN(estRows) || estRows < 0 {
		return 0
	}
	return uint64(estRows)
}
// parseSnapshotToTSO converts snapshot into a TSO value.
//
// A snapshot that parses as an unsigned decimal integer is returned unchanged.
// Otherwise it is passed to the server's unix_timestamp(); the resulting number
// of seconds is shifted left by 18 bits and multiplied by 1000 (equivalently,
// the millisecond timestamp shifted left by 18 bits). A NULL result means the
// string is not a supported time format and yields an error.
func parseSnapshotToTSO(pool *sql.DB, snapshot string) (uint64, error) {
	if ts, parseErr := strconv.ParseUint(snapshot, 10, 64); parseErr == nil {
		return ts, nil
	}
	const query = "SELECT unix_timestamp(?)"
	var seconds sql.NullInt64
	if err := pool.QueryRow(query, snapshot).Scan(&seconds); err != nil {
		return 0, errors.Annotatef(err, "sql: %s", strings.ReplaceAll(query, "?", fmt.Sprintf(`"%s"`, snapshot)))
	}
	if !seconds.Valid {
		return 0, errors.Errorf("snapshot %s format not supported. please use tso or '2006-01-02 15:04:05' format time", snapshot)
	}
	return (uint64(seconds.Int64) << 18) * 1000, nil
}
// buildWhereCondition combines the user filter conf.Where with an extra
// condition where into a single WHERE clause.
//
// Empty parts are skipped; "" is returned when both are empty. When both are
// present each is parenthesized and they are joined with AND, e.g.
// "WHERE (a) AND (b) ". The result always ends with a space when non-empty.
func buildWhereCondition(conf *Config, where string) string {
	open, closing := " ", " "
	if conf.Where != "" && where != "" {
		open, closing = " (", ") "
	}
	keyword := "WHERE"
	var b strings.Builder
	for _, cond := range []string{conf.Where, where} {
		if cond == "" {
			continue
		}
		b.WriteString(keyword)
		b.WriteString(open)
		b.WriteString(cond)
		b.WriteString(closing)
		keyword = "AND"
	}
	return b.String()
}
// escapeString makes s safe to place inside a backtick-quoted identifier by
// doubling every backtick it contains.
func escapeString(s string) string {
	if !strings.Contains(s, "`") {
		return s
	}
	return strings.Join(strings.Split(s, "`"), "``")
}
// GetPartitionNames returns the PARTITION_NAME values that
// INFORMATION_SCHEMA.PARTITIONS lists for schema.table, skipping NULL names.
// The returned slice is never nil; on error it may hold the names read so far.
func GetPartitionNames(tctx *tcontext.Context, db *BaseConn, schema, table string) (partitions []string, err error) {
	const query = "SELECT PARTITION_NAME from INFORMATION_SCHEMA.PARTITIONS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?"
	partitions = make([]string, 0)
	collect := func(rows *sql.Rows) error {
		var name sql.NullString
		if scanErr := rows.Scan(&name); scanErr != nil {
			return errors.Trace(scanErr)
		}
		if name.Valid {
			partitions = append(partitions, name.String)
		}
		return nil
	}
	// reset discards names collected by a previous attempt.
	reset := func() {
		partitions = partitions[:0]
	}
	err = db.QuerySQL(tctx, collect, reset, query, schema, table)
	return partitions, err
}
// GetPartitionTableIDs returns the partition IDs of the requested tables as
// dbName -> tbName -> partitionName -> partitionID.
//
// SHOW STATS_HISTOGRAMS reports db_name, table_name and partition_name but no
// partition ID, while mysql.stats_histograms has the partition ID but none of
// the names. The two result sets are joined on (UPDATE_TIME, DISTINCT_COUNT):
// records with equal values are assumed to describe the same partition.
// Inaccurate histograms or duplicate (UPDATE_TIME, DISTINCT_COUNT) pairs are
// tolerated. They are unlikely, and their only effect is that more than one
// region may be selected at a time, which affects neither the correctness of
// the dumped data nor, noticeably, memory usage. The approach is tricky, but no
// better one is available, because TiDB v3.0.0's
// information_schema.partition table has no partition name or ID.
func GetPartitionTableIDs(db *sql.Conn, tables map[string]map[string]struct{}) (map[string]map[string]map[string]int64, error) {
	const (
		showStatsHistogramsSQL   = "SHOW STATS_HISTOGRAMS"
		selectStatsHistogramsSQL = "SELECT TABLE_ID,FROM_UNIXTIME(VERSION DIV 262144 DIV 1000,'%Y-%m-%d %H:%i:%s') AS UPDATE_TIME,DISTINCT_COUNT FROM mysql.stats_histograms"
	)
	type partitionInfo struct {
		dbName, tbName, partitionName string
	}

	rows, err := db.QueryContext(context.Background(), showStatsHistogramsSQL)
	if err != nil {
		return nil, errors.Annotatef(err, "sql: %s", showStatsHistogramsSQL)
	}
	histograms, err := GetSpecifiedColumnValuesAndClose(rows, "DB_NAME", "TABLE_NAME", "PARTITION_NAME", "UPDATE_TIME", "DISTINCT_COUNT")
	if err != nil {
		return nil, errors.Annotatef(err, "sql: %s", showStatsHistogramsSQL)
	}

	// Index the requested partitions by UPDATE_TIME, then DISTINCT_COUNT.
	byStats := make(map[string]map[string]partitionInfo)
	for _, h := range histograms {
		dbName, tbName, partitionName := h[0], h[1], h[2]
		if partitionName == "" {
			continue
		}
		if _, wanted := tables[dbName][tbName]; !wanted {
			continue
		}
		updateTime, distinctCount := h[3], h[4]
		inner, ok := byStats[updateTime]
		if !ok {
			inner = make(map[string]partitionInfo)
			byStats[updateTime] = inner
		}
		inner[distinctCount] = partitionInfo{
			dbName:        dbName,
			tbName:        tbName,
			partitionName: partitionName,
		}
	}
	if len(byStats) == 0 {
		return map[string]map[string]map[string]int64{}, nil
	}

	partitionIDs := make(map[string]map[string]map[string]int64, len(tables))
	err = simpleQuery(db, selectStatsHistogramsSQL, func(rows *sql.Rows) error {
		var (
			tableID                   int64
			updateTime, distinctCount string
		)
		if scanErr := rows.Scan(&tableID, &updateTime, &distinctCount); scanErr != nil {
			return errors.Trace(scanErr)
		}
		info, ok := byStats[updateTime][distinctCount]
		if !ok {
			return nil
		}
		tbMap, ok := partitionIDs[info.dbName]
		if !ok {
			tbMap = make(map[string]map[string]int64)
			partitionIDs[info.dbName] = tbMap
		}
		ptMap, ok := tbMap[info.tbName]
		if !ok {
			ptMap = make(map[string]int64)
			tbMap[info.tbName] = ptMap
		}
		ptMap[info.partitionName] = tableID
		return nil
	})
	return partitionIDs, err
}
// GetDBInfo builds model.DBInfo entries for the requested tables (schema ->
// table set) from INFORMATION_SCHEMA.TABLES. Each table carries its
// TIDB_TABLE_ID, which is needed to check whether a region belongs to it, and
// the partition IDs found by GetPartitionTableIDs when it is partitioned.
func GetDBInfo(db *sql.Conn, tables map[string]map[string]struct{}) ([]*model.DBInfo, error) {
	const tableIDSQL = "SELECT TABLE_SCHEMA,TABLE_NAME,TIDB_TABLE_ID FROM INFORMATION_SCHEMA.TABLES ORDER BY TABLE_SCHEMA"

	partitionIDs, err := GetPartitionTableIDs(db, tables)
	if err != nil {
		return nil, err
	}

	schemas := make([]*model.DBInfo, 0, len(tables))
	var (
		schemaName, tableName string
		tableID               int64
	)
	err = simpleQuery(db, tableIDSQL, func(rows *sql.Rows) error {
		if scanErr := rows.Scan(&schemaName, &tableName, &tableID); scanErr != nil {
			return errors.Trace(scanErr)
		}
		if _, wanted := tables[schemaName][tableName]; !wanted {
			return nil
		}
		// Rows arrive ordered by TABLE_SCHEMA, so a new schema can only follow
		// the most recently appended one.
		if n := len(schemas); n == 0 || schemas[n-1].Name.O != schemaName {
			schemas = append(schemas, &model.DBInfo{
				Name:   model.CIStr{O: schemaName},
				Tables: make([]*model.TableInfo, 0, len(tables[schemaName])),
			})
		}
		current := schemas[len(schemas)-1]

		var partition *model.PartitionInfo
		if ptm, ok := partitionIDs[schemaName][tableName]; ok {
			partition = &model.PartitionInfo{Definitions: make([]model.PartitionDefinition, 0, len(ptm))}
			for name, id := range ptm {
				partition.Definitions = append(partition.Definitions, model.PartitionDefinition{
					ID:   id,
					Name: model.CIStr{O: name},
				})
			}
		}
		current.Tables = append(current.Tables, &model.TableInfo{
			ID:        tableID,
			Name:      model.CIStr{O: tableName},
			Partition: partition,
		})
		return nil
	})
	return schemas, err
}
// GetRegionInfos get region info including regionID, start key, end key from database sql interface.
// start key, end key includes information to help split table
func GetRegionInfos(db *sql.Conn) (*helper.RegionsInfo, error) {
const tableRegionSQL = "SELECT REGION_ID,START_KEY,END_KEY FROM INFORMATION_SCHEMA.TIKV_REGION_STATUS ORDER BY START_KEY;"
var (
regionID int64
startKey, endKey string
)
regionsInfo := &helper.RegionsInfo{Regions: make([]helper.RegionInfo, 0)}
err := simpleQuery(db, tableRegionSQL, func(rows *sql.Rows) error {
err := rows.Scan(®ionID, &startKey, &endKey)
if err != nil {
return errors.Trace(err)
}
regionsInfo.Regions = append(regionsInfo.Regions, helper.RegionInfo{
ID: regionID,
StartKey: startKey,
EndKey: endKey,
})
return nil
})
return regionsInfo, err
}
// GetCharsetAndDefaultCollation returns a map from lower-cased charset name to
// that charset's default collation, read from SHOW CHARACTER SET.
// Any query, scan, iteration or close error is returned annotated with the
// query text, together with a nil map.
func GetCharsetAndDefaultCollation(ctx context.Context, db *sql.Conn) (map[string]string, error) {
	charsetAndDefaultCollation := make(map[string]string)
	query := "SHOW CHARACTER SET"
	// Show an example.
	/*
		mysql> SHOW CHARACTER SET;
		+----------+---------------------------------+---------------------+--------+
		| Charset  | Description                     | Default collation   | Maxlen |
		+----------+---------------------------------+---------------------+--------+
		| armscii8 | ARMSCII-8 Armenian              | armscii8_general_ci |      1 |
		| ascii    | US ASCII                        | ascii_general_ci    |      1 |
		| big5     | Big5 Traditional Chinese        | big5_chinese_ci     |      2 |
		| binary   | Binary pseudo charset           | binary              |      1 |
		| cp1250   | Windows Central European        | cp1250_general_ci   |      1 |
		| cp1251   | Windows Cyrillic                | cp1251_general_ci   |      1 |
		+----------+---------------------------------+---------------------+--------+
	*/
	rows, err := db.QueryContext(ctx, query)
	if err != nil {
		return nil, errors.Annotatef(err, "sql: %s", query)
	}
	defer rows.Close()
	for rows.Next() {
		var charset, description, collation string
		var maxlen int
		if scanErr := rows.Scan(&charset, &description, &collation, &maxlen); scanErr != nil {
			// Annotate scanErr itself: err is nil here, and annotating a nil error
			// yields nil, which would silently hide the scan failure.
			return nil, errors.Annotatef(scanErr, "sql: %s", query)
		}
		charsetAndDefaultCollation[strings.ToLower(charset)] = collation
	}
	// Check iteration errors first so a close error cannot mask them.
	if err = rows.Err(); err != nil {
		return nil, errors.Annotatef(err, "sql: %s", query)
	}
	if err = rows.Close(); err != nil {
		return nil, errors.Annotatef(err, "sql: %s", query)
	}
	return charsetAndDefaultCollation, nil
}
相关信息
相关文章
0
赞
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
8、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦