wmi.go 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. // +build windows
  2. /*
  3. Package wmi provides a WQL interface for WMI on Windows.
  4. Example code to print names of running processes:
  5. type Win32_Process struct {
  6. Name string
  7. }
  8. func main() {
  9. var dst []Win32_Process
  10. q := wmi.CreateQuery(&dst, "")
  11. err := wmi.Query(q, &dst)
  12. if err != nil {
  13. log.Fatal(err)
  14. }
  15. for i, v := range dst {
  16. println(i, v.Name)
  17. }
  18. }
  19. */
  20. package wmi
  21. import (
  22. "bytes"
  23. "errors"
  24. "fmt"
  25. "log"
  26. "os"
  27. "reflect"
  28. "runtime"
  29. "strconv"
  30. "strings"
  31. "sync"
  32. "time"
  33. "github.com/go-ole/go-ole"
  34. "github.com/go-ole/go-ole/oleutil"
  35. )
  // Package-level state shared by every query issued through this package.
  var (
  	// l is the package logger. It writes to standard output using the
  	// standard library's default date/time prefix.
  	l = log.New(os.Stdout, "", log.LstdFlags)

  	// ErrInvalidEntityType is returned by Client.Query when dst is not a
  	// non-nil pointer to a slice of structs or of struct pointers.
  	ErrInvalidEntityType = errors.New("wmi: invalid entity type")

  	// ErrNilCreateObject is the error returned if CreateObject returns nil even
  	// if the error was nil.
  	ErrNilCreateObject = errors.New("wmi: create object returned nil")

  	// lock is held by Client.Query for the whole duration of a query, so
  	// queries issued through that method never run concurrently.
  	lock sync.Mutex
  )

  // S_FALSE is returned by CoInitializeEx if it was already called on this thread.
  const S_FALSE = 0x00000001
  46. // QueryNamespace invokes Query with the given namespace on the local machine.
  47. func QueryNamespace(query string, dst interface{}, namespace string) error {
  48. return Query(query, dst, nil, namespace)
  49. }
  50. // Query runs the WQL query and appends the values to dst.
  51. //
  52. // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
  53. // the query must have the same name in dst. Supported types are all signed and
  54. // unsigned integers, time.Time, string, bool, or a pointer to one of those.
  55. // Array types are not supported.
  56. //
  57. // By default, the local machine and default namespace are used. These can be
  58. // changed using connectServerArgs. See
  59. // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
  60. //
  61. // Query is a wrapper around DefaultClient.Query.
  62. func Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
  63. if DefaultClient.SWbemServicesClient == nil {
  64. return DefaultClient.Query(query, dst, connectServerArgs...)
  65. }
  66. return DefaultClient.SWbemServicesClient.Query(query, dst, connectServerArgs...)
  67. }
  68. // A Client is an WMI query client.
  69. //
  70. // Its zero value (DefaultClient) is a usable client.
  71. type Client struct {
  72. // NonePtrZero specifies if nil values for fields which aren't pointers
  73. // should be returned as the field types zero value.
  74. //
  75. // Setting this to true allows stucts without pointer fields to be used
  76. // without the risk failure should a nil value returned from WMI.
  77. NonePtrZero bool
  78. // PtrNil specifies if nil values for pointer fields should be returned
  79. // as nil.
  80. //
  81. // Setting this to true will set pointer fields to nil where WMI
  82. // returned nil, otherwise the types zero value will be returned.
  83. PtrNil bool
  84. // AllowMissingFields specifies that struct fields not present in the
  85. // query result should not result in an error.
  86. //
  87. // Setting this to true allows custom queries to be used with full
  88. // struct definitions instead of having to define multiple structs.
  89. AllowMissingFields bool
  90. // SWbemServiceClient is an optional SWbemServices object that can be
  91. // initialized and then reused across multiple queries. If it is null
  92. // then the method will initialize a new temporary client each time.
  93. SWbemServicesClient *SWbemServices
  94. }
  95. // DefaultClient is the default Client and is used by Query, QueryNamespace
  96. var DefaultClient = &Client{}
  97. // Query runs the WQL query and appends the values to dst.
  98. //
  99. // dst must have type *[]S or *[]*S, for some struct type S. Fields selected in
  100. // the query must have the same name in dst. Supported types are all signed and
  101. // unsigned integers, time.Time, string, bool, or a pointer to one of those.
  102. // Array types are not supported.
  103. //
  104. // By default, the local machine and default namespace are used. These can be
  105. // changed using connectServerArgs. See
  106. // http://msdn.microsoft.com/en-us/library/aa393720.aspx for details.
  107. func (c *Client) Query(query string, dst interface{}, connectServerArgs ...interface{}) error {
  108. dv := reflect.ValueOf(dst)
  109. if dv.Kind() != reflect.Ptr || dv.IsNil() {
  110. return ErrInvalidEntityType
  111. }
  112. dv = dv.Elem()
  113. mat, elemType := checkMultiArg(dv)
  114. if mat == multiArgTypeInvalid {
  115. return ErrInvalidEntityType
  116. }
  117. lock.Lock()
  118. defer lock.Unlock()
  119. runtime.LockOSThread()
  120. defer runtime.UnlockOSThread()
  121. err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED)
  122. if err != nil {
  123. oleCode := err.(*ole.OleError).Code()
  124. if oleCode != ole.S_OK && oleCode != S_FALSE {
  125. return err
  126. }
  127. }
  128. defer ole.CoUninitialize()
  129. unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator")
  130. if err != nil {
  131. return err
  132. } else if unknown == nil {
  133. return ErrNilCreateObject
  134. }
  135. defer unknown.Release()
  136. wmi, err := unknown.QueryInterface(ole.IID_IDispatch)
  137. if err != nil {
  138. return err
  139. }
  140. defer wmi.Release()
  141. // service is a SWbemServices
  142. serviceRaw, err := oleutil.CallMethod(wmi, "ConnectServer", connectServerArgs...)
  143. if err != nil {
  144. return err
  145. }
  146. service := serviceRaw.ToIDispatch()
  147. defer serviceRaw.Clear()
  148. // result is a SWBemObjectSet
  149. resultRaw, err := oleutil.CallMethod(service, "ExecQuery", query)
  150. if err != nil {
  151. return err
  152. }
  153. result := resultRaw.ToIDispatch()
  154. defer resultRaw.Clear()
  155. count, err := oleInt64(result, "Count")
  156. if err != nil {
  157. return err
  158. }
  159. enumProperty, err := result.GetProperty("_NewEnum")
  160. if err != nil {
  161. return err
  162. }
  163. defer enumProperty.Clear()
  164. enum, err := enumProperty.ToIUnknown().IEnumVARIANT(ole.IID_IEnumVariant)
  165. if err != nil {
  166. return err
  167. }
  168. if enum == nil {
  169. return fmt.Errorf("can't get IEnumVARIANT, enum is nil")
  170. }
  171. defer enum.Release()
  172. // Initialize a slice with Count capacity
  173. dv.Set(reflect.MakeSlice(dv.Type(), 0, int(count)))
  174. var errFieldMismatch error
  175. for itemRaw, length, err := enum.Next(1); length > 0; itemRaw, length, err = enum.Next(1) {
  176. if err != nil {
  177. return err
  178. }
  179. err := func() error {
  180. // item is a SWbemObject, but really a Win32_Process
  181. item := itemRaw.ToIDispatch()
  182. defer item.Release()
  183. ev := reflect.New(elemType)
  184. if err = c.loadEntity(ev.Interface(), item); err != nil {
  185. if _, ok := err.(*ErrFieldMismatch); ok {
  186. // We continue loading entities even in the face of field mismatch errors.
  187. // If we encounter any other error, that other error is returned. Otherwise,
  188. // an ErrFieldMismatch is returned.
  189. errFieldMismatch = err
  190. } else {
  191. return err
  192. }
  193. }
  194. if mat != multiArgTypeStructPtr {
  195. ev = ev.Elem()
  196. }
  197. dv.Set(reflect.Append(dv, ev))
  198. return nil
  199. }()
  200. if err != nil {
  201. return err
  202. }
  203. }
  204. return errFieldMismatch
  205. }
  // ErrFieldMismatch is returned when a field is to be loaded into a different
  // type than the one it was stored from, or when a field is missing or
  // unexported in the destination struct.
  type ErrFieldMismatch struct {
  	// StructType is the type involved in the mismatch. Within this file it
  	// is populated with the type of the destination field being loaded.
  	StructType reflect.Type
  	// FieldName is the name of the destination struct field.
  	FieldName string
  	// Reason is a short human-readable explanation of the failure.
  	Reason string
  }

  // Error implements the error interface, naming the field, the type and the
  // reason for the mismatch.
  func (e *ErrFieldMismatch) Error() string {
  	const format = "wmi: cannot load field %q into a %q: %s"
  	return fmt.Sprintf(format, e.FieldName, e.StructType, e.Reason)
  }
  // timeType is the reflect.Type of time.Time, used to recognise struct fields
  // that should be filled from a WMI datetime string.
  var timeType = reflect.TypeOf((*time.Time)(nil)).Elem()
  220. // loadEntity loads a SWbemObject into a struct pointer.
  221. func (c *Client) loadEntity(dst interface{}, src *ole.IDispatch) (errFieldMismatch error) {
  222. v := reflect.ValueOf(dst).Elem()
  223. for i := 0; i < v.NumField(); i++ {
  224. f := v.Field(i)
  225. of := f
  226. isPtr := f.Kind() == reflect.Ptr
  227. if isPtr {
  228. ptr := reflect.New(f.Type().Elem())
  229. f.Set(ptr)
  230. f = f.Elem()
  231. }
  232. n := v.Type().Field(i).Name
  233. if !f.CanSet() {
  234. return &ErrFieldMismatch{
  235. StructType: of.Type(),
  236. FieldName: n,
  237. Reason: "CanSet() is false",
  238. }
  239. }
  240. prop, err := oleutil.GetProperty(src, n)
  241. if err != nil {
  242. if !c.AllowMissingFields {
  243. errFieldMismatch = &ErrFieldMismatch{
  244. StructType: of.Type(),
  245. FieldName: n,
  246. Reason: "no such struct field",
  247. }
  248. }
  249. continue
  250. }
  251. defer prop.Clear()
  252. if prop.Value() == nil {
  253. continue
  254. }
  255. switch val := prop.Value().(type) {
  256. case int8, int16, int32, int64, int:
  257. v := reflect.ValueOf(val).Int()
  258. switch f.Kind() {
  259. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  260. f.SetInt(v)
  261. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  262. f.SetUint(uint64(v))
  263. default:
  264. return &ErrFieldMismatch{
  265. StructType: of.Type(),
  266. FieldName: n,
  267. Reason: "not an integer class",
  268. }
  269. }
  270. case uint8, uint16, uint32, uint64:
  271. v := reflect.ValueOf(val).Uint()
  272. switch f.Kind() {
  273. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  274. f.SetInt(int64(v))
  275. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  276. f.SetUint(v)
  277. default:
  278. return &ErrFieldMismatch{
  279. StructType: of.Type(),
  280. FieldName: n,
  281. Reason: "not an integer class",
  282. }
  283. }
  284. case string:
  285. switch f.Kind() {
  286. case reflect.String:
  287. f.SetString(val)
  288. case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
  289. iv, err := strconv.ParseInt(val, 10, 64)
  290. if err != nil {
  291. return err
  292. }
  293. f.SetInt(iv)
  294. case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
  295. uv, err := strconv.ParseUint(val, 10, 64)
  296. if err != nil {
  297. return err
  298. }
  299. f.SetUint(uv)
  300. case reflect.Struct:
  301. switch f.Type() {
  302. case timeType:
  303. if len(val) == 25 {
  304. mins, err := strconv.Atoi(val[22:])
  305. if err != nil {
  306. return err
  307. }
  308. val = val[:22] + fmt.Sprintf("%02d%02d", mins/60, mins%60)
  309. }
  310. t, err := time.Parse("20060102150405.000000-0700", val)
  311. if err != nil {
  312. return err
  313. }
  314. f.Set(reflect.ValueOf(t))
  315. }
  316. }
  317. case bool:
  318. switch f.Kind() {
  319. case reflect.Bool:
  320. f.SetBool(val)
  321. default:
  322. return &ErrFieldMismatch{
  323. StructType: of.Type(),
  324. FieldName: n,
  325. Reason: "not a bool",
  326. }
  327. }
  328. case float32:
  329. switch f.Kind() {
  330. case reflect.Float32:
  331. f.SetFloat(float64(val))
  332. default:
  333. return &ErrFieldMismatch{
  334. StructType: of.Type(),
  335. FieldName: n,
  336. Reason: "not a Float32",
  337. }
  338. }
  339. default:
  340. if f.Kind() == reflect.Slice {
  341. switch f.Type().Elem().Kind() {
  342. case reflect.String:
  343. safeArray := prop.ToArray()
  344. if safeArray != nil {
  345. arr := safeArray.ToValueArray()
  346. fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
  347. for i, v := range arr {
  348. s := fArr.Index(i)
  349. s.SetString(v.(string))
  350. }
  351. f.Set(fArr)
  352. }
  353. case reflect.Uint8:
  354. safeArray := prop.ToArray()
  355. if safeArray != nil {
  356. arr := safeArray.ToValueArray()
  357. fArr := reflect.MakeSlice(f.Type(), len(arr), len(arr))
  358. for i, v := range arr {
  359. s := fArr.Index(i)
  360. s.SetUint(reflect.ValueOf(v).Uint())
  361. }
  362. f.Set(fArr)
  363. }
  364. default:
  365. return &ErrFieldMismatch{
  366. StructType: of.Type(),
  367. FieldName: n,
  368. Reason: fmt.Sprintf("unsupported slice type (%T)", val),
  369. }
  370. }
  371. } else {
  372. typeof := reflect.TypeOf(val)
  373. if typeof == nil && (isPtr || c.NonePtrZero) {
  374. if (isPtr && c.PtrNil) || (!isPtr && c.NonePtrZero) {
  375. of.Set(reflect.Zero(of.Type()))
  376. }
  377. break
  378. }
  379. return &ErrFieldMismatch{
  380. StructType: of.Type(),
  381. FieldName: n,
  382. Reason: fmt.Sprintf("unsupported type (%T)", val),
  383. }
  384. }
  385. }
  386. }
  387. return errFieldMismatch
  388. }
  // multiArgType classifies the element type of a destination slice.
  type multiArgType int

  const (
  	// multiArgTypeInvalid marks a value that is not a slice of structs or
  	// of struct pointers.
  	multiArgTypeInvalid multiArgType = iota
  	// multiArgTypeStruct marks a slice of the form []S.
  	multiArgTypeStruct
  	// multiArgTypeStructPtr marks a slice of the form []*S.
  	multiArgTypeStructPtr
  )

  // checkMultiArg reports whether v is a slice of the form []S or []*S for
  // some struct type S.
  //
  // On success it returns the matching category together with the
  // reflect.Type of S. For every other input it returns multiArgTypeInvalid
  // and a nil type.
  func checkMultiArg(v reflect.Value) (m multiArgType, elemType reflect.Type) {
  	if v.Kind() != reflect.Slice {
  		return multiArgTypeInvalid, nil
  	}

  	elem := v.Type().Elem()
  	if elem.Kind() == reflect.Struct {
  		return multiArgTypeStruct, elem
  	}
  	if elem.Kind() == reflect.Ptr {
  		if target := elem.Elem(); target.Kind() == reflect.Struct {
  			return multiArgTypeStructPtr, target
  		}
  	}
  	return multiArgTypeInvalid, nil
  }
  415. func oleInt64(item *ole.IDispatch, prop string) (int64, error) {
  416. v, err := oleutil.GetProperty(item, prop)
  417. if err != nil {
  418. return 0, err
  419. }
  420. defer v.Clear()
  421. i := int64(v.Val)
  422. return i, nil
  423. }
  // CreateQuery returns a WQL query string that queries all columns of src.
  //
  // src may be a struct S, a slice []S or []*S, or a pointer to any of those
  // (the same destination shapes Query accepts). The selected columns are the
  // names of S's fields, in declaration order, and the queried class is the
  // name of S. Only the type of src is inspected, so nil pointers and empty
  // slices are fine. If src is nil or no struct type can be derived from it,
  // the empty string is returned.
  //
  // where is an optional string that is appended to the query after a single
  // space, to be used with WHERE clauses. In such a case, the "WHERE" string
  // should appear at the beginning.
  func CreateQuery(src interface{}, where string) string {
  	t := reflect.TypeOf(src)
  	if t == nil {
  		// Untyped nil carries no type information to build a query from.
  		return ""
  	}
  	if t.Kind() == reflect.Ptr {
  		t = t.Elem()
  	}
  	if t.Kind() == reflect.Slice {
  		t = t.Elem()
  		// Accept []*S as well as []S.
  		if t.Kind() == reflect.Ptr {
  			t = t.Elem()
  		}
  	}
  	if t.Kind() != reflect.Struct {
  		return ""
  	}

  	fields := make([]string, 0, t.NumField())
  	for i := 0; i < t.NumField(); i++ {
  		fields = append(fields, t.Field(i).Name)
  	}

  	var b bytes.Buffer
  	b.WriteString("SELECT ")
  	b.WriteString(strings.Join(fields, ", "))
  	b.WriteString(" FROM ")
  	b.WriteString(t.Name())
  	b.WriteString(" " + where)
  	return b.String()
  }
  447. }