hadoop DBInputFormat 源码

  • 2022-10-20
  • 浏览 (140)

hadoop DBInputFormat 代码

文件路径:/hadoop-mapreduce-project/hadoop-mapreduce-client/hadoop-mapreduce-client-core/src/main/java/org/apache/hadoop/mapreduce/lib/db/DBInputFormat.java

/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.mapreduce.lib.db;

import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;
import org.apache.hadoop.conf.Configurable;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.mapreduce.InputFormat;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.Job;
import org.apache.hadoop.mapreduce.JobContext;
import org.apache.hadoop.mapreduce.MRJobConfig;
import org.apache.hadoop.mapreduce.RecordReader;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.hadoop.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * An InputFormat that reads input data from an SQL table.
 * <p>
 * DBInputFormat emits LongWritables containing the record number as 
 * key and DBWritables as value. 
 * 
 * The SQL query and input class can be set using one of the two 
 * setInput methods.
 */
@InterfaceAudience.Public
@InterfaceStability.Stable
public class DBInputFormat<T extends DBWritable>
    extends InputFormat<LongWritable, T> implements Configurable {

  private static final Logger LOG =
      LoggerFactory.getLogger(DBInputFormat.class);
  
  /** Upper-cased database product name; "DEFAULT" until setConf() runs. */
  protected String dbProductName = "DEFAULT";

  /**
   * A Class that does nothing, implementing DBWritable
   */
  @InterfaceStability.Evolving
  public static class NullDBWritable implements DBWritable, Writable {
    @Override
    public void readFields(DataInput in) throws IOException { }
    @Override
    public void readFields(ResultSet arg0) throws SQLException { }
    @Override
    public void write(DataOutput out) throws IOException { }
    @Override
    public void write(PreparedStatement arg0) throws SQLException { }
  }
  
  /**
   * A InputSplit that spans a set of rows
   */
  @InterfaceStability.Evolving
  public static class DBInputSplit extends InputSplit implements Writable {

    private long end = 0;
    private long start = 0;

    /**
     * Default Constructor
     */
    public DBInputSplit() {
    }

    /**
     * Convenience Constructor
     * @param start the index of the first row to select
     * @param end the index of the last row to select
     */
    public DBInputSplit(long start, long end) {
      this.start = start;
      this.end = end;
    }

    /** {@inheritDoc} */
    @Override
    public String[] getLocations() throws IOException {
      // TODO Add a layer to enable SQL "sharding" and support locality
      return new String[] {};
    }

    /**
     * @return The index of the first row to select
     */
    public long getStart() {
      return start;
    }

    /**
     * @return The index of the last row to select
     */
    public long getEnd() {
      return end;
    }

    /**
     * @return The total row count in this split, computed as
     * {@code end - start}
     */
    @Override
    public long getLength() throws IOException {
      return end - start;
    }

    /**
     * Reads {@code start} then {@code end} as two longs, mirroring
     * {@link #write(DataOutput)}.
     */
    @Override
    public void readFields(DataInput input) throws IOException {
      start = input.readLong();
      end = input.readLong();
    }

    /**
     * Writes {@code start} then {@code end} as two longs.
     */
    @Override
    public void write(DataOutput output) throws IOException {
      output.writeLong(start);
      output.writeLong(end);
    }
  }

  protected String conditions;

  protected Connection connection;

  protected String tableName;

  protected String[] fieldNames;

  protected DBConfiguration dbConf;

  /**
   * Wraps the configuration in a {@link DBConfiguration}, opens a connection,
   * records the upper-cased database product name, and caches the input
   * table name, field names and conditions.
   *
   * @param conf the configuration to read the database settings from
   * @throws RuntimeException if the connection cannot be created or its
   * metadata cannot be read
   */
  @Override
  public void setConf(Configuration conf) {

    dbConf = new DBConfiguration(conf);

    try {
      this.connection = createConnection();

      DatabaseMetaData dbMeta = connection.getMetaData();
      this.dbProductName =
          StringUtils.toUpperCase(dbMeta.getDatabaseProductName());
    }
    catch (Exception ex) {
      throw new RuntimeException(ex);
    }

    tableName = dbConf.getInputTableName();
    fieldNames = dbConf.getInputFieldNames();
    conditions = dbConf.getInputConditions();
  }

  /** @return the configuration backing the current DBConfiguration */
  @Override
  public Configuration getConf() {
    return dbConf.getConf();
  }
  
  /** @return the DBConfiguration created by {@link #setConf(Configuration)} */
  public DBConfiguration getDBConf() {
    return dbConf;
  }

  /**
   * Returns the shared connection, creating it first if there is none
   * (for example after {@link #closeConnection()} has run).
   *
   * @return the connection held by this input format
   */
  public Connection getConnection() {
    // TODO Remove this code that handles backward compatibility.
    if (this.connection == null) {
      this.connection = createConnection();
    }

    return this.connection;
  }

  /**
   * Opens a new connection with auto-commit disabled and the isolation level
   * set to TRANSACTION_SERIALIZABLE.
   *
   * @return a freshly configured connection; the caller owns it
   * @throws RuntimeException wrapping any failure to open or configure the
   * connection
   */
  public Connection createConnection() {
    Connection newConnection = null;
    try {
      newConnection = dbConf.getConnection();
      newConnection.setAutoCommit(false);
      newConnection.setTransactionIsolation(
          Connection.TRANSACTION_SERIALIZABLE);

      return newConnection;
    } catch (Exception e) {
      // Do not leak a connection that was opened but could not be configured.
      closeQuietly(newConnection);
      throw new RuntimeException(e);
    }
  }

  /** @return the upper-cased database product name, or "DEFAULT" */
  public String getDBProductName() {
    return dbProductName;
  }

  /**
   * Creates the record reader matching the database product: an Oracle or
   * MySQL specific reader when the product name starts with "ORACLE" or
   * "MYSQL", the generic reader otherwise. Each reader gets its own new
   * connection.
   *
   * @param split the row range to read
   * @param conf the job configuration handed to the reader
   * @return a record reader over the given split
   * @throws IOException if constructing the reader fails with an SQLException;
   * the SQLException is kept as the cause
   */
  protected RecordReader<LongWritable, T> createDBRecordReader(DBInputSplit split,
      Configuration conf) throws IOException {

    @SuppressWarnings("unchecked")
    Class<T> inputClass = (Class<T>) (dbConf.getInputClass());
    Connection readerConnection = createConnection();
    try {
      // use database product name to determine appropriate record reader.
      if (dbProductName.startsWith("ORACLE")) {
        // use Oracle-specific db reader.
        return new OracleDBRecordReader<T>(split, inputClass,
            conf, readerConnection, getDBConf(), conditions, fieldNames,
            tableName);
      } else if (dbProductName.startsWith("MYSQL")) {
        // use MySQL-specific db reader.
        return new MySQLDBRecordReader<T>(split, inputClass,
            conf, readerConnection, getDBConf(), conditions, fieldNames,
            tableName);
      } else {
        // Generic reader.
        return new DBRecordReader<T>(split, inputClass,
            conf, readerConnection, getDBConf(), conditions, fieldNames,
            tableName);
      }
    } catch (SQLException ex) {
      // No reader was created, so nothing else will ever close the connection.
      closeQuietly(readerConnection);
      throw new IOException(ex.getMessage(), ex);
    }
  }

  /** {@inheritDoc} */
  @Override
  public RecordReader<LongWritable, T> createRecordReader(InputSplit split,
      TaskAttemptContext context) throws IOException, InterruptedException {  

    return createDBRecordReader((DBInputSplit) split, context.getConfiguration());
  }

  /**
   * Runs the count query and divides the row range into as many splits as
   * the configured number of maps (at least one). The last split absorbs
   * the remainder. The shared connection is closed before returning.
   *
   * @param job the job context supplying the number of maps
   * @return the list of row-range splits
   * @throws IOException if the count query fails or returns no row
   */
  @Override
  public List<InputSplit> getSplits(JobContext job) throws IOException {

    ResultSet results = null;  
    Statement statement = null;
    try {
      // Use getConnection() rather than the field: the finally block below
      // closes and nulls the field, so a repeated call must reconnect.
      Connection conn = getConnection();
      statement = conn.createStatement();

      String countQuery = getCountQuery();
      results = statement.executeQuery(countQuery);
      if (!results.next()) {
        throw new IOException("Count query returned no rows: " + countQuery);
      }

      long count = results.getLong(1);
      // A configured map count of zero or less would otherwise divide by
      // zero or silently produce no splits at all.
      int chunks = Math.max(1,
          job.getConfiguration().getInt(MRJobConfig.NUM_MAPS, 1));
      long chunkSize = (count / chunks);

      results.close();
      results = null;
      statement.close();
      statement = null;

      List<InputSplit> splits = new ArrayList<InputSplit>();

      // Split the rows into n-number of chunks and adjust the last chunk
      // accordingly
      for (int i = 0; i < chunks; i++) {
        DBInputSplit split;

        if ((i + 1) == chunks) {
          split = new DBInputSplit(i * chunkSize, count);
        } else {
          split = new DBInputSplit(i * chunkSize, (i * chunkSize)
              + chunkSize);
        }

        splits.add(split);
      }

      conn.commit();
      return splits;
    } catch (SQLException e) {
      throw new IOException("Got SQLException", e);
    } finally {
      try {
        if (results != null) { results.close(); }
      } catch (SQLException e1) {
        LOG.debug("Exception on close", e1);
      }
      try {
        if (statement != null) { statement.close(); }
      } catch (SQLException e1) {
        LOG.debug("Exception on close", e1);
      }

      closeConnection();
    }
  }

  /** Returns the query for getting the total number of rows, 
   * subclasses can override this for custom behaviour.
   * A configured input count query takes precedence; otherwise a
   * SELECT COUNT(*) over the table is built, with the conditions appended
   * as a WHERE clause when they are non-empty. */
  protected String getCountQuery() {
    
    if(dbConf.getInputCountQuery() != null) {
      return dbConf.getInputCountQuery();
    }
    
    StringBuilder query = new StringBuilder();
    query.append("SELECT COUNT(*) FROM ").append(tableName);

    if (conditions != null && conditions.length() > 0) {
      query.append(" WHERE ").append(conditions);
    }
    return query.toString();
  }

  /**
   * Initializes the map-part of the job with the appropriate input settings.
   * 
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the 
   * Java object holding tuple fields.
   * @param tableName The table to read data from
   * @param conditions The condition which to select data with, 
   * eg. '(updated &gt; 20070101 AND length &gt; 0)'
   * @param orderBy the fieldNames in the orderBy clause.
   * @param fieldNames The field names in the table
   * @see #setInput(Job, Class, String, String)
   */
  public static void setInput(Job job, 
      Class<? extends DBWritable> inputClass,
      String tableName,String conditions, 
      String orderBy, String... fieldNames) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputTableName(tableName);
    dbConf.setInputFieldNames(fieldNames);
    dbConf.setInputConditions(conditions);
    dbConf.setInputOrderBy(orderBy);
  }
  
  /**
   * Initializes the map-part of the job with the appropriate input settings.
   * 
   * @param job The map-reduce job
   * @param inputClass the class object implementing DBWritable, which is the 
   * Java object holding tuple fields.
   * @param inputQuery the input query to select fields. Example : 
   * "SELECT f1, f2, f3 FROM Mytable ORDER BY f1"
   * @param inputCountQuery the input query that returns 
   * the number of records in the table. 
   * Example : "SELECT COUNT(f1) FROM Mytable"
   * @see #setInput(Job, Class, String, String, String, String...)
   */
  public static void setInput(Job job,
      Class<? extends DBWritable> inputClass,
      String inputQuery, String inputCountQuery) {
    job.setInputFormatClass(DBInputFormat.class);
    DBConfiguration dbConf = new DBConfiguration(job.getConfiguration());
    dbConf.setInputClass(inputClass);
    dbConf.setInputQuery(inputQuery);
    dbConf.setInputCountQuery(inputCountQuery);
  }

  /**
   * Closes the shared connection, if any, and clears the field. A failure
   * to close is logged at debug level and otherwise ignored.
   */
  protected void closeConnection() {
    try {
      if (null != this.connection) {
        this.connection.close();
        this.connection = null;
      }
    } catch (SQLException sqlE) {
      LOG.debug("Exception on close", sqlE);
    }
  }

  /** Closes a connection on a failure path, logging instead of throwing. */
  private static void closeQuietly(Connection conn) {
    if (conn == null) {
      return;
    }
    try {
      conn.close();
    } catch (SQLException sqlE) {
      LOG.debug("Exception on close", sqlE);
    }
  }
}

相关信息

hadoop 源码目录

相关文章

hadoop BigDecimalSplitter 源码

hadoop BooleanSplitter 源码

hadoop DBConfiguration 源码

hadoop DBOutputFormat 源码

hadoop DBRecordReader 源码

hadoop DBSplitter 源码

hadoop DBWritable 源码

hadoop DataDrivenDBInputFormat 源码

hadoop DataDrivenDBRecordReader 源码

hadoop DateSplitter 源码

0  赞